mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(fix)(chat):fix typo in s2sql and add prompt to output. (#581)
* 1.refactor the retrieval module. 2.refactor the http service module. 3.upgrade text2sql output format the parse for absolute time related expression in query. * fix bug. * upgrade the config module, now support config llm supported by langchain. * fix conflicts. * update text2sql config reload to be compatible with new config format. * modify default config. * 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantical retriever modules. * merge with upstream master * add general retrieve service. * add api service for sql_agent for crud operations of few-shots examples. * modify requirements * add auto-cot feature * 1. output log to a fixed log file. 2.allow few-shots examples tied to data model, and add strategy that extend examples when retrieved examples tied to a data model is not enough. 3. fix misformat in s2ql args. * add prior_ext to output. * fix typo in s2sql * add prompt to output. --------- Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
@@ -25,4 +25,4 @@ LLM_PROVIDER_NAME = openai
|
||||
[LLMModel]
|
||||
MODEL_NAME = gpt-3.5-turbo-16k
|
||||
OPENAI_API_KEY = YOUR_API_KEY
|
||||
TEMPERATURE = 0.0
|
||||
TEMPERATURE = 0.0
|
||||
@@ -1,4 +1,4 @@
|
||||
examplars= [
|
||||
exemplars= [
|
||||
{ "currentDate":"2020-12-01",
|
||||
"tableName":"内容库产品",
|
||||
"fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
|
||||
@@ -15,7 +15,7 @@ from instances.logging_instance import logger
|
||||
from config.config_parse import (
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT,)
|
||||
from few_shot_example.s2ql_examplar import examplars as sql_examplars
|
||||
from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars
|
||||
|
||||
|
||||
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
|
||||
@@ -35,6 +35,6 @@ def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
|
||||
|
||||
if __name__ == "__main__":
|
||||
text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT,
|
||||
sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
|
||||
sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
|
||||
|
||||
|
||||
@@ -11,15 +11,15 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import json
|
||||
|
||||
from s2ql.constructor import FewShotPromptTemplate2
|
||||
from s2ql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
|
||||
from s2sql.constructor import FewShotPromptTemplate2
|
||||
from s2sql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
|
||||
|
||||
from instances.llm_instance import llm
|
||||
from instances.chromadb_instance import client as chromadb_client
|
||||
from instances.logging_instance import logger
|
||||
from instances.text2vec_instance import emb_func
|
||||
|
||||
from few_shot_example.s2ql_examplar import examplars as sql_examplars
|
||||
from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars
|
||||
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTACT_COLLECTION_NAME,
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
ACT_MIN_WINDOWN_SIZE, ACT_MAX_WINDOWN_SIZE)
|
||||
@@ -45,8 +45,8 @@ text2sql_agent_autoCoT = Text2DSLAgentAutoCoT(num_fewshots=TEXT2DSL_FEWSHOTS_NUM
|
||||
sql_example_prompter=text2dsl_agent_act_example_prompter, llm=llm,
|
||||
auto_cot_min_window_size=ACT_MIN_WINDOWN_SIZE, auto_cot_max_window_size=ACT_MAX_WINDOWN_SIZE)
|
||||
|
||||
sql_ids = [str(i) for i in range(0, len(sql_examplars))]
|
||||
text2sql_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
|
||||
sql_ids = [str(i) for i in range(0, len(sql_exemplars))]
|
||||
text2sql_agent.reload_setting(sql_ids, sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
|
||||
|
||||
if text2sql_agent_autoCoT.count_examples()==0:
|
||||
source_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -14,9 +14,9 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from s2ql.constructor import FewShotPromptTemplate2
|
||||
from s2ql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
from s2ql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
|
||||
from s2sql.constructor import FewShotPromptTemplate2
|
||||
from s2sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
from s2sql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
|
||||
|
||||
|
||||
class Text2DSLAgentBase(object):
|
||||
@@ -248,6 +248,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt]
|
||||
|
||||
resp['schemaLinkingOutput'] = schema_link_output
|
||||
resp['schemaLinkStr'] = schema_link_str
|
||||
|
||||
@@ -285,6 +287,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['prompt'] = [schema_linking_sql_shortcut_prompt]
|
||||
|
||||
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
|
||||
resp['schemaLinkStr'] = schema_linking_str
|
||||
resp['sqlOutput'] = sql_str
|
||||
@@ -303,7 +307,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool]
|
||||
|
||||
return schema_linking_str_pool
|
||||
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
|
||||
|
||||
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
@@ -313,7 +317,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
sql_output_pool = await asyncio.gather(*[self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
|
||||
logger.debug("sql_output_pool->{}".format(sql_output_pool))
|
||||
|
||||
return sql_output_pool
|
||||
return sql_output_pool, sql_prompt_pool
|
||||
|
||||
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
@@ -322,7 +326,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
|
||||
logger.debug("schema_linking_sql_output_res_pool->{}".format(schema_linking_sql_output_res_pool))
|
||||
|
||||
return schema_linking_sql_output_res_pool
|
||||
return schema_linking_sql_output_res_pool, schema_linking_sql_prompt_pool, schema_linking_sql_output_task_pool
|
||||
|
||||
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
@@ -338,14 +342,14 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
||||
|
||||
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
||||
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
||||
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
||||
|
||||
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
|
||||
|
||||
sql_output_candicates = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
|
||||
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
|
||||
|
||||
@@ -357,6 +361,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt for schema_linking_prompt, sql_prompt in zip(schema_linking_prompt_list, sql_output_prompt_list)]
|
||||
|
||||
resp['schemaLinkStr'] = schema_linking_output_max
|
||||
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
|
||||
|
||||
@@ -380,7 +386,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
||||
|
||||
schema_linking_sql_output_candidates = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
|
||||
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
|
||||
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
|
||||
@@ -400,6 +406,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['prompt'] = schema_linking_sql_prompt_list
|
||||
|
||||
resp['schemaLinkStr'] = schema_linking_output_max
|
||||
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
|
||||
|
||||
@@ -8,7 +8,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from services.s2ql.run import text2sql_agent_router
|
||||
from services.s2sql.run import text2sql_agent_router
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user