feat. Terms for llm prompt (#1042)

Co-authored-by: guixuanwen
This commit is contained in:
GuiXuan Wen
2024-05-29 14:45:11 +08:00
committed by GitHub
parent bd9cc8f88a
commit 9c3509fc1f
3 changed files with 61 additions and 31 deletions

View File

@@ -1,6 +1,6 @@
import os
import sys
from typing import List, Union, Mapping, Any
from typing import Dict, List, Optional, Union, Mapping, Any
from collections import Counter
import random
import asyncio
@@ -158,7 +158,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str],
schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
fewshot_example_list:List[Mapping[str, str]])-> str:
fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = [])-> str:
instruction = "# Use the the schema links to generate the SQL queries for each of the questions."
sql_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"]
@@ -168,7 +168,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list=terms_list)
new_case_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: "
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
@@ -179,26 +179,27 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def generate_sql_prompt_pool(self, question: str, domain_name: str, fields_list: List[str],
                             schema_link_str_pool: List[str], current_date: str,
                             prior_schema_links: Mapping[str, str], prior_exts: str,
                             fewshot_example_list_pool: List[List[Mapping[str, str]]],
                             terms_list: Optional[List[Dict]] = None) -> List[str]:
    """Build one SQL prompt per (schema-link string, few-shot example list) pair.

    The two pools are walked in lockstep with ``zip``, so the result has as
    many prompts as the shorter pool. Every other argument is forwarded
    unchanged to ``self.generate_sql_prompt`` for each pair.

    Args:
        question: Question text forwarded to ``generate_sql_prompt``.
        domain_name: Forwarded to ``generate_sql_prompt``.
        fields_list: Forwarded to ``generate_sql_prompt``.
        schema_link_str_pool: One schema-link string per prompt to build.
        current_date: Forwarded to ``generate_sql_prompt``.
        prior_schema_links: Forwarded to ``generate_sql_prompt``.
        prior_exts: Forwarded to ``generate_sql_prompt``.
        fewshot_example_list_pool: One few-shot example list per prompt,
            paired positionally with ``schema_link_str_pool``.
        terms_list: Optional term entries forwarded as ``terms_list``;
            ``None`` (the default) is treated as an empty list.

    Returns:
        The prompts returned by ``generate_sql_prompt``, in pool order.
    """
    # A ``[]`` default would be a single list object shared by every call;
    # default to None and create a fresh list per call instead.
    terms = [] if terms_list is None else terms_list
    return [
        self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str,
                                 current_date, prior_schema_links, prior_exts,
                                 fewshot_example_list, terms_list=terms)
        for schema_link_str, fewshot_example_list
        in zip(schema_link_str_pool, fewshot_example_list_pool)
    ]
def generate_schema_linking_sql_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]]):
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = []):
instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions."
example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT","sql"]
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}"
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}\n"
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
example_keys=example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list)
new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Lets think step by step. In the question "{questionAugmented2}", we are asked:"""
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented)
@@ -208,10 +209,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return prompt
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date: str, domain_name: str,
                                            fields_list: List[str],
                                            prior_schema_links: Mapping[str, str], prior_exts: str,
                                            fewshot_example_list_pool: List[List[Mapping[str, str]]],
                                            terms_list: Optional[List[Dict]] = None) -> List[str]:
    """Build one combined schema-linking + SQL prompt per few-shot example list.

    Each entry of ``fewshot_example_list_pool`` is passed, together with the
    remaining arguments, to ``self.generate_schema_linking_sql_prompt``.

    Args:
        question: Question text forwarded to the single-prompt builder.
        current_date: Forwarded to the single-prompt builder.
        domain_name: Forwarded to the single-prompt builder.
        fields_list: Forwarded to the single-prompt builder.
        prior_schema_links: Forwarded to the single-prompt builder.
        prior_exts: Forwarded to the single-prompt builder.
        fewshot_example_list_pool: One few-shot example list per prompt.
        terms_list: Optional term entries forwarded as ``terms_list``;
            ``None`` (the default) is treated as an empty list.

    Returns:
        The prompts, in the same order as ``fewshot_example_list_pool``.
    """
    # Avoid a mutable ``[]`` default (one list shared by all calls): take
    # None as the default and build a fresh list per call.
    terms = [] if terms_list is None else terms_list
    return [
        self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list,
                                                prior_schema_links, prior_exts,
                                                fewshot_example_list, terms_list=terms)
        for fewshot_example_list in fewshot_example_list_pool
    ]
@@ -219,7 +220,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
llm_config:dict):
llm_config:dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -227,6 +228,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
@@ -238,7 +241,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
schema_link_str = schema_link_parse(schema_link_output)
logger.debug("schema_link_str->{}".format(schema_link_str))
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list)
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list=terms_list)
logger.debug("sql_prompt->{}".format(sql_prompt))
sql_output = await llm._call_async(sql_prompt)
@@ -264,7 +267,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
llm_config:dict):
llm_config:dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -272,9 +275,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list)
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
llm = get_llm(llm_config)
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
@@ -316,9 +320,9 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict, terms_list: Optional[List[Dict]] = []):
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))
llm = get_llm(llm_config)
sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
@@ -328,8 +332,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict, terms_list: Optional[List[Dict]] = []):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
llm = get_llm(llm_config)
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
@@ -339,7 +343,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -347,18 +351,20 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config,)
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config, terms_list=terms_list)
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
@@ -383,7 +389,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
return resp
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
@@ -391,11 +397,12 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
logger.info("terms_list: {}".format(terms_list))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
@@ -781,26 +788,26 @@ class Text2DSLAgentWrapper(object):
async def async_query2sql(self, question: str, filter_condition: Mapping[str, str],
                          model_name: str, fields_list: List[str],
                          data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str,
                          sql_generation_mode: str, llm_config: dict,
                          terms_list: Optional[List[Dict]] = None):
    """Dispatch a text-to-SQL request to the agent method for the given mode.

    ``sql_generation_mode`` is first validated against the values of
    ``SqlModeEnum``; the matching coroutine on ``self.sql_agent_act`` is then
    awaited with the request arguments (``data_date`` is passed on as
    ``current_date``).

    Args:
        question: Question text passed to the agent.
        filter_condition: Passed to the agent unchanged.
        model_name: Passed to the agent unchanged.
        fields_list: Passed to the agent unchanged.
        data_date: Passed to the agent as ``current_date``.
        prior_schema_links: Passed to the agent unchanged.
        prior_exts: Passed to the agent unchanged.
        sql_generation_mode: One of ``'1_pass_auto_cot'``,
            ``'1_pass_auto_cot_self_consistency'``, ``'2_pass_auto_cot'`` or
            ``'2_pass_auto_cot_self_consistency'``.
        llm_config: Passed to the agent unchanged.
        terms_list: Optional term entries passed to the agent; ``None``
            (the default) is treated as an empty list.

    Returns:
        Whatever the selected agent coroutine returns.

    Raises:
        ValueError: If ``sql_generation_mode`` is not a ``SqlModeEnum`` value,
            or is an enum value with no handler here.
    """
    if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
        raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")

    # Mode -> name of the coroutine on ``self.sql_agent_act`` that serves it.
    # Names (not bound methods) are stored so only the selected attribute is
    # looked up, as in the original if/elif chain.
    agent_method_by_mode = {
        '1_pass_auto_cot': 'async_query2sql_shortcut',
        '1_pass_auto_cot_self_consistency': 'tasks_run_shortcut',
        '2_pass_auto_cot': 'async_query2sql',
        '2_pass_auto_cot_self_consistency': 'tasks_run',
    }
    agent_method_name = agent_method_by_mode.get(sql_generation_mode)
    if agent_method_name is None:
        # Reachable only if SqlModeEnum gains a value without a handler above.
        raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')

    logger.info(f"sql wrapper: {sql_generation_mode}")
    # A ``[]`` default would be one list shared across every call; use None
    # as the default and hand the agent a fresh list instead.
    terms = [] if terms_list is None else terms_list
    agent_method = getattr(self.sql_agent_act, agent_method_name)
    resp = await agent_method(question=question, filter_condition=filter_condition,
                              model_name=model_name, fields_list=fields_list,
                              current_date=data_date, prior_schema_links=prior_schema_links,
                              prior_exts=prior_exts, llm_config=llm_config, terms_list=terms)
    return resp