feat. Terms for llm prompt (#1042)

Co-authored-by: guixuanwen
This commit is contained in:
GuiXuan Wen
2024-05-29 14:45:11 +08:00
committed by GitHub
parent bd9cc8f88a
commit 9c3509fc1f
3 changed files with 61 additions and 31 deletions

View File

@@ -2,7 +2,9 @@
import os import os
import sys import sys
from typing import Any, List, Union, Mapping from typing import Any, Dict, List, Union, Mapping
from git import Optional
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -14,7 +16,7 @@ from auto_cot import auto_cot_run
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None): def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None, terms_list: Optional[List[Dict]] = []):
db_schema = f"Table: {table_name}, Columns = {field_list}\nForeign_keys: []" db_schema = f"Table: {table_name}, Columns = {field_list}\nForeign_keys: []"
prior_linkings_pairs = [] prior_linkings_pairs = []
@@ -40,7 +42,27 @@ def transform_sql_example(question:str, current_date:str, table_name:str, field_
current_data_str = """当前的日期是{}""".format(current_date) current_data_str = """当前的日期是{}""".format(current_date)
question_augmented = """{question} (补充信息:{prior_linking}{current_date}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str) terms_desc = ''
if len(terms_list) > 0:
terms_desc += "相关业务术语:"
for idx, term in enumerate(terms_list):
if (term['description'] is not None and len(term['description']) > 0) and (term['alias'] is not None and len(term['alias']) > 0):
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,它通常是指<{term['description']}>,类似的表达还有{term['alias']}"""
elif (term['description'] is None or len(term['description']) == 0) and (term['alias'] is not None and len(term['alias']) > 0):
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,类似的表达还有{term['alias']}"""
elif (term['description'] is not None and len(term['description']) > 0) and (term['alias'] is None or len(term['alias']) == 0):
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,它通常是指<{term['description']}>"""
else:
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语;"""
if len(terms_desc) > 0:
terms_desc = terms_desc[:-1]
question_augmented = """{question} (补充信息:{prior_linking}{current_date}{terms_desc}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str)
return question_augmented, db_schema, sql return question_augmented, db_schema, sql

View File

@@ -1,6 +1,6 @@
import os import os
import sys import sys
from typing import List, Union, Mapping, Any from typing import Dict, List, Optional, Union, Mapping, Any
from collections import Counter from collections import Counter
import random import random
import asyncio import asyncio
@@ -158,7 +158,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str], def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str],
schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str, schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
fewshot_example_list:List[Mapping[str, str]])-> str: fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = [])-> str:
instruction = "# Use the the schema links to generate the SQL queries for each of the questions." instruction = "# Use the the schema links to generate the SQL queries for each of the questions."
sql_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"] sql_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"]
@@ -168,7 +168,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
example_keys=sql_example_keys, example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list) few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts) question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list=terms_list)
new_case_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: " new_case_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: "
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str) new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
@@ -179,26 +179,27 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str], def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str],
schema_link_str_pool: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str, schema_link_str_pool: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]: fewshot_example_list_pool:List[List[Mapping[str, str]]], terms_list: Optional[List[Dict]] = [])-> List[str]:
sql_prompt_pool = [] sql_prompt_pool = []
for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool): for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
sql_prompt = self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_list) sql_prompt = self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_list, terms_list=terms_list)
sql_prompt_pool.append(sql_prompt) sql_prompt_pool.append(sql_prompt)
return sql_prompt_pool return sql_prompt_pool
def generate_schema_linking_sql_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str], def generate_schema_linking_sql_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]]): prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = []):
instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions." instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions."
example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT","sql"] example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT","sql"]
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}" example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}\n"
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template, fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
example_keys=example_keys, example_keys=example_keys,
few_shot_example_meta_list=fewshot_example_list) few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts) question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list)
new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Lets think step by step. In the question "{questionAugmented2}", we are asked:""" new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Lets think step by step. In the question "{questionAugmented2}", we are asked:"""
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented) new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented)
@@ -208,10 +209,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return prompt return prompt
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str], def generate_schema_linking_sql_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]: prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]], terms_list: Optional[List[Dict]] = [])-> List[str]:
schema_linking_sql_prompt_pool = [] schema_linking_sql_prompt_pool = []
for fewshot_example_list in fewshot_example_list_pool: for fewshot_example_list in fewshot_example_list_pool:
schema_linking_sql_prompt = self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list) schema_linking_sql_prompt = self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list, terms_list=terms_list)
schema_linking_sql_prompt_pool.append(schema_linking_sql_prompt) schema_linking_sql_prompt_pool.append(schema_linking_sql_prompt)
return schema_linking_sql_prompt_pool return schema_linking_sql_prompt_pool
@@ -219,7 +220,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str], async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
llm_config:dict): llm_config:dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question)) logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition)) logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name)) logger.info("model_name: {}".format(model_name))
@@ -227,6 +228,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date)) logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links)) logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts)) logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples) fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list) schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
@@ -238,7 +241,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
schema_link_str = schema_link_parse(schema_link_output) schema_link_str = schema_link_parse(schema_link_output)
logger.debug("schema_link_str->{}".format(schema_link_str)) logger.debug("schema_link_str->{}".format(schema_link_str))
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list) sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list=terms_list)
logger.debug("sql_prompt->{}".format(sql_prompt)) logger.debug("sql_prompt->{}".format(sql_prompt))
sql_output = await llm._call_async(sql_prompt) sql_output = await llm._call_async(sql_prompt)
@@ -264,7 +267,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str], async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
llm_config:dict): llm_config:dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question)) logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition)) logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name)) logger.info("model_name: {}".format(model_name))
@@ -272,9 +275,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date)) logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links)) logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts)) logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples) fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list) schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list)
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt)) logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
llm = get_llm(llm_config) llm = get_llm(llm_config)
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt) schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
@@ -316,9 +320,9 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str], async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict): current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict, terms_list: Optional[List[Dict]] = []):
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool)) logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))
llm = get_llm(llm_config) llm = get_llm(llm_config)
sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]) sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
@@ -328,8 +332,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str], async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict): fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict, terms_list: Optional[List[Dict]] = []):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo) schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
llm = get_llm(llm_config) llm = get_llm(llm_config)
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool] schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool) schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
@@ -339,7 +343,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def tasks_run(self, question: str, filter_condition: Mapping[str,str], async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict): current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question)) logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition)) logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name)) logger.info("model_name: {}".format(model_name))
@@ -347,18 +351,20 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date)) logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links)) logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts)) logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples) fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots) fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config) schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config,)
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}') logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list) schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}') logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list) schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config) sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config, terms_list=terms_list)
logger.debug(f'sql_output_candicates:{sql_output_candicates}') logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates) sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
@@ -383,7 +389,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return resp return resp
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str], async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict): current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question)) logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition)) logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name)) logger.info("model_name: {}".format(model_name))
@@ -391,11 +397,12 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date)) logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links)) logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts)) logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples) fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots) fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}') logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates] schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}') logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
@@ -781,26 +788,26 @@ class Text2DSLAgentWrapper(object):
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str], async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str], model_name: str, fields_list: List[str],
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str, llm_config: dict): data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum): if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum") raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")
if sql_generation_mode == '1_pass_auto_cot': if sql_generation_mode == '1_pass_auto_cot':
logger.info(f"sql wrapper: {sql_generation_mode}") logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config) resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
return resp return resp
elif sql_generation_mode == '1_pass_auto_cot_self_consistency': elif sql_generation_mode == '1_pass_auto_cot_self_consistency':
logger.info(f"sql wrapper: {sql_generation_mode}") logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config) resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
return resp return resp
elif sql_generation_mode == '2_pass_auto_cot': elif sql_generation_mode == '2_pass_auto_cot':
logger.info(f"sql wrapper: {sql_generation_mode}") logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config) resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
return resp return resp
elif sql_generation_mode == '2_pass_auto_cot_self_consistency': elif sql_generation_mode == '2_pass_auto_cot_self_consistency':
logger.info(f"sql wrapper: {sql_generation_mode}") logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config) resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
return resp return resp
else: else:
raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum') raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')

View File

@@ -59,6 +59,7 @@ async def query2sql(query_body: Mapping[str, Any]):
dataset_name = schema['dataSetName'] dataset_name = schema['dataSetName']
fields_list = schema['fieldNameList'] fields_list = schema['fieldNameList']
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking} prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
terms_list = schema['terms']
resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition, resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition,
model_name=dataset_name, fields_list=fields_list, model_name=dataset_name, fields_list=fields_list,