mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
Add feature to s2sql that allow few-shots example tied to data model. (#571)
* 1.refactor the retrieval module. 2.refactor the http service module. 3.upgrade text2sql output format the parse for absolute time related expression in query. * fix bug. * upgrade the config module, now support config llm suppoted by langchain. * fix conflicts. * update text2sql config reload to be compitable with new config format. * modify default config. * 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantical retriever modules. * merege with upstream master * add general retrieve service. * add api service for sql_agent for crud opereations of few-shots examples. * modify requirements * add auto-cot feature * 1. output log to a fixed log file. 2.allow few-shots examples tied to data model, and add strategy that extend examples when retrieved examples tied to a data model is not enough. 3. fix misformat in s2ql args. * add prior_ext to output. --------- Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
@@ -70,6 +70,7 @@ def transform_sql_example_autoCoT_run(examplar_list, min_window_size, max_window
|
||||
transformed_sql_examplar = dict()
|
||||
transformed_sql_examplar['question'] = question
|
||||
transformed_sql_examplar['questionAugmented'] = question_augmented
|
||||
transformed_sql_examplar['modelName'] = table_name
|
||||
transformed_sql_examplar['dbSchema'] = db_schema
|
||||
transformed_sql_examplar['sql'] = sql
|
||||
transformed_sql_examplar['generatedSchemaLinkingCoT'] = generated_schema_linking_cot
|
||||
|
||||
@@ -33,8 +33,27 @@ class Text2DSLAgentBase(object):
|
||||
def get_examples_candidates(self, question: str, filter_condition: Mapping[str, str], num_examples: int)->List[Mapping[str, str]]:
|
||||
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(question, num_examples, filter_condition)
|
||||
|
||||
return few_shot_example_meta_list
|
||||
|
||||
if len(few_shot_example_meta_list) == num_examples:
|
||||
return few_shot_example_meta_list
|
||||
elif len(few_shot_example_meta_list) < num_examples:
|
||||
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} < num_examples: {num_examples}")
|
||||
existed_id_set = set([item['id'] for item in few_shot_example_meta_list])
|
||||
extra_few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text=question, retrieval_num=num_examples, filter_condition=None)
|
||||
|
||||
for item in extra_few_shot_example_meta_list:
|
||||
if item['id'] not in existed_id_set:
|
||||
few_shot_example_meta_list.append(item)
|
||||
existed_id_set.add(item['id'])
|
||||
if len(few_shot_example_meta_list) == num_examples:
|
||||
break
|
||||
|
||||
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} = num_examples: {num_examples}")
|
||||
return few_shot_example_meta_list
|
||||
else:
|
||||
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} > num_examples: {num_examples}")
|
||||
few_shot_example_meta_list = few_shot_example_meta_list[:num_examples]
|
||||
return few_shot_example_meta_list
|
||||
|
||||
def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]], num_fewshots:int)-> List[List[Mapping[str, str]]]:
|
||||
fewshot_example_list = []
|
||||
for i in range(0, self.num_self_consistency):
|
||||
@@ -124,6 +143,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
|
||||
|
||||
logger.info(f'schema_linking_prompt: {schema_linking_prompt}')
|
||||
return schema_linking_prompt
|
||||
|
||||
|
||||
@@ -153,7 +173,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
|
||||
|
||||
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
|
||||
|
||||
|
||||
logger.info(f'sql_example_prompt: {sql_example_prompt}')
|
||||
return sql_example_prompt
|
||||
|
||||
def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str],
|
||||
@@ -183,6 +204,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt
|
||||
|
||||
logger.info(f'schema_linking_sql_prompt: {prompt}')
|
||||
return prompt
|
||||
|
||||
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
|
||||
@@ -223,6 +245,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['model'] = model_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['schemaLinkingOutput'] = schema_link_output
|
||||
@@ -259,6 +282,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['model'] = model_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
|
||||
@@ -330,6 +354,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['model'] = model_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['schemaLinkStr'] = schema_linking_output_max
|
||||
@@ -372,6 +397,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
resp['model'] = model_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['priorExts'] = prior_exts
|
||||
resp['currentDate'] = current_date
|
||||
|
||||
resp['schemaLinkStr'] = schema_linking_output_max
|
||||
@@ -775,4 +801,4 @@ class Text2DSLAgentWrapper(object):
|
||||
def count_examples(self):
|
||||
sql_agent_examples_act_cnt = self.sql_agent_act.count_examples()
|
||||
|
||||
return sql_agent_examples_act_cnt
|
||||
return sql_agent_examples_act_cnt
|
||||
Reference in New Issue
Block a user