Add feature to s2sql that allow few-shots example tied to data model. (#571)

* 1.refactor the retrieval module. 2.refactor the http service module. 3.upgrade text2sql output format the parse for absolute time related expression in query.

* fix bug.

* upgrade the config module, now support config llm suppoted by langchain.

* fix conflicts.

* update text2sql config reload to be compitable with new config format.

* modify default config.

* 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantical retriever modules.

* merege with upstream master

* add general retrieve service.

* add api service for sql_agent for crud opereations of few-shots examples.

* modify requirements

* add auto-cot feature

* 1. output log to a fixed log file.  2.allow few-shots examples tied to data model, and add strategy that extend examples when retrieved examples tied to a data model is not enough. 3. fix misformat in s2ql args.

* add prior_ext to output.

---------

Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
codescracker
2023-12-27 19:39:50 +08:00
committed by GitHub
parent cf2b4bfb5c
commit b706c4efb4
6 changed files with 512 additions and 68 deletions

View File

@@ -70,6 +70,7 @@ def transform_sql_example_autoCoT_run(examplar_list, min_window_size, max_window
transformed_sql_examplar = dict()
transformed_sql_examplar['question'] = question
transformed_sql_examplar['questionAugmented'] = question_augmented
transformed_sql_examplar['modelName'] = table_name
transformed_sql_examplar['dbSchema'] = db_schema
transformed_sql_examplar['sql'] = sql
transformed_sql_examplar['generatedSchemaLinkingCoT'] = generated_schema_linking_cot

View File

@@ -33,8 +33,27 @@ class Text2DSLAgentBase(object):
def get_examples_candidates(self, question: str, filter_condition: Mapping[str, str], num_examples: int)->List[Mapping[str, str]]:
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(question, num_examples, filter_condition)
return few_shot_example_meta_list
if len(few_shot_example_meta_list) == num_examples:
return few_shot_example_meta_list
elif len(few_shot_example_meta_list) < num_examples:
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} < num_examples: {num_examples}")
existed_id_set = set([item['id'] for item in few_shot_example_meta_list])
extra_few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text=question, retrieval_num=num_examples, filter_condition=None)
for item in extra_few_shot_example_meta_list:
if item['id'] not in existed_id_set:
few_shot_example_meta_list.append(item)
existed_id_set.add(item['id'])
if len(few_shot_example_meta_list) == num_examples:
break
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} = num_examples: {num_examples}")
return few_shot_example_meta_list
else:
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} > num_examples: {num_examples}")
few_shot_example_meta_list = few_shot_example_meta_list[:num_examples]
return few_shot_example_meta_list
def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]], num_fewshots:int)-> List[List[Mapping[str, str]]]:
fewshot_example_list = []
for i in range(0, self.num_self_consistency):
@@ -124,6 +143,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
logger.info(f'schema_linking_prompt: {schema_linking_prompt}')
return schema_linking_prompt
@@ -153,7 +173,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
logger.info(f'sql_example_prompt: {sql_example_prompt}')
return sql_example_prompt
def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str],
@@ -183,6 +204,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt
logger.info(f'schema_linking_sql_prompt: {prompt}')
return prompt
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
@@ -223,6 +245,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['schemaLinkingOutput'] = schema_link_output
@@ -259,6 +282,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
@@ -330,6 +354,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['schemaLinkStr'] = schema_linking_output_max
@@ -372,6 +397,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['schemaLinkStr'] = schema_linking_output_max
@@ -775,4 +801,4 @@ class Text2DSLAgentWrapper(object):
def count_examples(self):
sql_agent_examples_act_cnt = self.sql_agent_act.count_examples()
return sql_agent_examples_act_cnt
return sql_agent_examples_act_cnt