feat. Terms for llm prompt (#1042)

Co-authored-by: guixuanwen
This commit is contained in:
GuiXuan Wen
2024-05-29 14:45:11 +08:00
committed by GitHub
parent bd9cc8f88a
commit 9c3509fc1f
3 changed files with 61 additions and 31 deletions

View File

@@ -2,7 +2,9 @@
import os
import sys
from typing import Any, List, Union, Mapping
from typing import Any, Dict, List, Union, Mapping
from git import Optional
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -14,7 +16,7 @@ from auto_cot import auto_cot_run
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None):
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None, terms_list: Optional[List[Dict]] = []):
db_schema = f"Table: {table_name}, Columns = {field_list}\nForeign_keys: []"
prior_linkings_pairs = []
@@ -40,7 +42,27 @@ def transform_sql_example(question:str, current_date:str, table_name:str, field_
current_data_str = """当前的日期是{}""".format(current_date)
question_augmented = """{question} (补充信息:{prior_linking}{current_date}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str)
terms_desc = ''
if len(terms_list) > 0:
terms_desc += "相关业务术语:"
for idx, term in enumerate(terms_list):
if (term['description'] is not None and len(term['description']) > 0) and (term['alias'] is not None and len(term['alias']) > 0):
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,它通常是指<{term['description']}>,类似的表达还有{term['alias']}"""
elif (term['description'] is None or len(term['description']) == 0) and (term['alias'] is not None and len(term['alias']) > 0):
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,类似的表达还有{term['alias']}"""
elif (term['description'] is not None and len(term['description']) > 0) and (term['alias'] is None or len(term['alias']) == 0):
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,它通常是指<{term['description']}>"""
else:
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语;"""
if len(terms_desc) > 0:
terms_desc = terms_desc[:-1]
question_augmented = """{question} (补充信息:{prior_linking}{current_date}{terms_desc}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str)
return question_augmented, db_schema, sql

View File

@@ -1,6 +1,6 @@
import os
import sys
from typing import List, Union, Mapping, Any
from typing import Dict, List, Optional, Union, Mapping, Any
from collections import Counter
import random
import asyncio
@@ -158,7 +158,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str],
schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
fewshot_example_list:List[Mapping[str, str]])-> str:
fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = [])-> str:
instruction = "# Use the the schema links to generate the SQL queries for each of the questions."
sql_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"]
@@ -168,7 +168,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list=terms_list)
new_case_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: "
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
@@ -179,26 +179,27 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def generate_sql_prompt_pool(self, question: str, domain_name: str, fields_list: List[str],
                             schema_link_str_pool: List[str], current_date: str,
                             prior_schema_links: Mapping[str, str], prior_exts: str,
                             fewshot_example_list_pool: List[List[Mapping[str, str]]],
                             terms_list: Optional[List[Dict]] = None) -> List[str]:
    """Build one SQL-generation prompt per (schema link, few-shot list) pair.

    ``schema_link_str_pool`` and ``fewshot_example_list_pool`` are walked in
    lockstep with ``zip``, so the result has as many prompts as the shorter
    of the two pools.

    Args:
        question: Question text forwarded to ``generate_sql_prompt``.
        domain_name: Forwarded unchanged to ``generate_sql_prompt``.
        fields_list: Forwarded unchanged to ``generate_sql_prompt``.
        schema_link_str_pool: One schema-link string per prompt.
        current_date: Forwarded unchanged to ``generate_sql_prompt``.
        prior_schema_links: Forwarded unchanged to ``generate_sql_prompt``.
        prior_exts: Forwarded unchanged to ``generate_sql_prompt``.
        fewshot_example_list_pool: One few-shot example list per prompt.
        terms_list: Optional term dicts forwarded to every prompt;
            ``None`` is treated as "no terms".

    Returns:
        The prompts, in pool order.
    """
    # A ``[]`` default would be one list object shared by every call, and an
    # explicit ``None`` must not be forwarded as-is; normalise once up front.
    terms = [] if terms_list is None else terms_list
    return [
        self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str, current_date,
                                 prior_schema_links, prior_exts, fewshot_example_list,
                                 terms_list=terms)
        for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool)
    ]
def generate_schema_linking_sql_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]]):
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = []):
instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions."
example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT","sql"]
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}"
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}\n"
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
example_keys=example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list)
new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Lets think step by step. In the question "{questionAugmented2}", we are asked:"""
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented)
@@ -208,10 +209,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return prompt
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date: str, domain_name: str,
                                            fields_list: List[str],
                                            prior_schema_links: Mapping[str, str], prior_exts: str,
                                            fewshot_example_list_pool: List[List[Mapping[str, str]]],
                                            terms_list: Optional[List[Dict]] = None) -> List[str]:
    """Build one combined schema-linking + SQL prompt per few-shot example list.

    Args:
        question: Question text forwarded to ``generate_schema_linking_sql_prompt``.
        current_date: Forwarded unchanged.
        domain_name: Forwarded unchanged.
        fields_list: Forwarded unchanged.
        prior_schema_links: Forwarded unchanged.
        prior_exts: Forwarded unchanged.
        fewshot_example_list_pool: One few-shot example list per prompt.
        terms_list: Optional term dicts forwarded to every prompt;
            ``None`` is treated as "no terms".

    Returns:
        The prompts, in the order of ``fewshot_example_list_pool``.
    """
    # A ``[]`` default would be one list object shared by every call, and an
    # explicit ``None`` must not be forwarded as-is; normalise once up front.
    terms = [] if terms_list is None else terms_list
    return [
        self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list,
                                                prior_schema_links, prior_exts, fewshot_example_list,
                                                terms_list=terms)
        for fewshot_example_list in fewshot_example_list_pool
    ]
@@ -219,7 +220,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
llm_config:dict):
llm_config:dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -227,6 +228,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
@@ -238,7 +241,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
schema_link_str = schema_link_parse(schema_link_output)
logger.debug("schema_link_str->{}".format(schema_link_str))
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list)
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list=terms_list)
logger.debug("sql_prompt->{}".format(sql_prompt))
sql_output = await llm._call_async(sql_prompt)
@@ -264,7 +267,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
llm_config:dict):
llm_config:dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -272,9 +275,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list)
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
llm = get_llm(llm_config)
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
@@ -316,9 +320,9 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict, terms_list: Optional[List[Dict]] = []):
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))
llm = get_llm(llm_config)
sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
@@ -328,8 +332,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict, terms_list: Optional[List[Dict]] = []):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
llm = get_llm(llm_config)
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
@@ -339,7 +343,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -347,18 +351,20 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config,)
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config, terms_list=terms_list)
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
@@ -383,7 +389,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return resp
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -391,11 +397,12 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
@@ -781,26 +788,26 @@ class Text2DSLAgentWrapper(object):
async def async_query2sql(self, question: str, filter_condition: Mapping[str, str],
                          model_name: str, fields_list: List[str],
                          data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str,
                          sql_generation_mode: str, llm_config: dict,
                          terms_list: Optional[List[Dict]] = None):
    """Dispatch a text-to-SQL request to the ``sql_agent_act`` method for the given mode.

    Args:
        question: Forwarded to the agent as ``question``.
        filter_condition: Forwarded to the agent as ``filter_condition``.
        model_name: Forwarded to the agent as ``model_name``.
        fields_list: Forwarded to the agent as ``fields_list``.
        data_date: Forwarded to the agent as ``current_date``.
        prior_schema_links: Forwarded to the agent as ``prior_schema_links``.
        prior_exts: Forwarded to the agent as ``prior_exts``.
        sql_generation_mode: Must equal the ``value`` of a ``SqlModeEnum``
            member and be one of the four modes handled below.
        llm_config: Forwarded to the agent as ``llm_config``.
        terms_list: Optional term dicts forwarded to the agent;
            ``None`` is treated as "no terms".

    Returns:
        Whatever the selected agent coroutine returns.

    Raises:
        ValueError: If ``sql_generation_mode`` is not a ``SqlModeEnum`` value,
            or is a value with no handler here.
    """
    if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
        raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")

    # Mode string -> name of the ``sql_agent_act`` coroutine that serves it.
    # Names (not bound methods) so only the selected attribute is looked up.
    agent_method_by_mode = {
        '1_pass_auto_cot': 'async_query2sql_shortcut',
        '1_pass_auto_cot_self_consistency': 'tasks_run_shortcut',
        '2_pass_auto_cot': 'async_query2sql',
        '2_pass_auto_cot_self_consistency': 'tasks_run',
    }
    method_name = agent_method_by_mode.get(sql_generation_mode)
    if method_name is None:
        # The mode is a valid enum value but has no handler wired up here.
        raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')

    logger.info(f"sql wrapper: {sql_generation_mode}")
    agent_method = getattr(self.sql_agent_act, method_name)
    # A ``[]`` default would be one list object shared by every call, and an
    # explicit ``None`` must not be forwarded as-is; normalise before the call.
    resp = await agent_method(question=question, filter_condition=filter_condition,
                              model_name=model_name, fields_list=fields_list,
                              current_date=data_date, prior_schema_links=prior_schema_links,
                              prior_exts=prior_exts, llm_config=llm_config,
                              terms_list=[] if terms_list is None else terms_list)
    return resp

View File

@@ -59,6 +59,7 @@ async def query2sql(query_body: Mapping[str, Any]):
dataset_name = schema['dataSetName']
fields_list = schema['fieldNameList']
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
terms_list = schema['terms']
resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition,
model_name=dataset_name, fields_list=fields_list,