|
|
|
|
@@ -17,17 +17,17 @@ from instances.logging_instance import logger
|
|
|
|
|
from s2sql.constructor import FewShotPromptTemplate2
|
|
|
|
|
from s2sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
|
|
|
|
from s2sql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
|
|
|
|
|
from instances.llm_instance import get_llm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Text2DSLAgentBase(object):
|
|
|
|
|
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
             sql_example_prompter:FewShotPromptTemplate2) -> None:
    """Store the few-shot sampling settings and the SQL example prompter.

    No LLM client is kept on the instance: the LLM-calling methods visible
    in this file obtain one per call with ``get_llm(llm_config)``.

    Args:
        num_fewshots: Stored as ``self.num_fewshots``; must not exceed
            ``num_examples``.
        num_examples: Stored as ``self.num_examples``.
        num_self_consistency: Stored as ``self.num_self_consistency``.
        sql_example_prompter: Stored as ``self.sql_example_prompter``.

    Raises:
        AssertionError: If ``num_fewshots`` is greater than ``num_examples``.
    """
    self.num_fewshots = num_fewshots
    self.num_examples = num_examples
    # Kept as an assert so callers see the same exception type as before.
    assert self.num_fewshots <= self.num_examples
    self.num_self_consistency = num_self_consistency

    self.sql_example_prompter = sql_example_prompter
|
|
|
|
|
|
|
|
|
|
def get_examples_candidates(self, question: str, filter_condition: Mapping[str, str], num_examples: int)->List[Mapping[str, str]]:
|
|
|
|
|
@@ -82,9 +82,9 @@ class Text2DSLAgentBase(object):
|
|
|
|
|
|
|
|
|
|
class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
|
|
|
|
|
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM,
|
|
|
|
|
sql_example_prompter:FewShotPromptTemplate2,
|
|
|
|
|
auto_cot_min_window_size: int, auto_cot_max_window_size: int):
|
|
|
|
|
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter, llm)
|
|
|
|
|
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter)
|
|
|
|
|
|
|
|
|
|
assert auto_cot_min_window_size <= auto_cot_max_window_size
|
|
|
|
|
self.auto_cot_min_window_size = auto_cot_min_window_size
|
|
|
|
|
@@ -218,7 +218,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
|
|
|
|
model_name: str, fields_list: List[str],
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
|
|
|
|
llm_config:dict):
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("filter_condition: {}".format(filter_condition))
|
|
|
|
|
logger.info("model_name: {}".format(model_name))
|
|
|
|
|
@@ -230,7 +231,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
|
|
|
|
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
|
|
|
|
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
|
|
|
|
|
schema_link_output = await self.llm._call_async(schema_linking_prompt)
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_link_output = await llm._call_async(schema_linking_prompt)
|
|
|
|
|
logger.debug("schema_link_output->{}".format(schema_link_output))
|
|
|
|
|
|
|
|
|
|
schema_link_str = schema_link_parse(schema_link_output)
|
|
|
|
|
@@ -238,7 +240,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
|
|
|
|
logger.debug("sql_prompt->{}".format(sql_prompt))
|
|
|
|
|
sql_output = await self.llm._call_async(sql_prompt)
|
|
|
|
|
sql_output = await llm._call_async(sql_prompt)
|
|
|
|
|
|
|
|
|
|
resp = dict()
|
|
|
|
|
resp['question'] = question
|
|
|
|
|
@@ -261,7 +263,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
|
|
|
|
|
model_name: str, fields_list: List[str],
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
|
|
|
|
llm_config:dict):
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("filter_condition: {}".format(filter_condition))
|
|
|
|
|
logger.info("model_name: {}".format(model_name))
|
|
|
|
|
@@ -273,7 +276,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
|
|
|
|
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
|
|
|
|
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
|
|
|
|
|
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
|
|
|
|
|
logger.debug("schema_linking_sql_shortcut_output->{}".format(schema_linking_sql_shortcut_output))
|
|
|
|
|
|
|
|
|
|
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
|
|
|
|
|
@@ -298,11 +302,13 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
return resp
|
|
|
|
|
|
|
|
|
|
async def generate_schema_linking_tasks(self, question: str, model_name: str, fields_list: List[str],
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
|
|
|
|
fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
|
|
|
|
|
|
|
|
|
|
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
|
|
|
|
logger.debug("schema_linking_prompt_pool->{}".format(schema_linking_prompt_pool))
|
|
|
|
|
schema_linking_output_pool = await asyncio.gather(*[self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool])
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_linking_output_pool = await asyncio.gather(*[llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool])
|
|
|
|
|
logger.debug("schema_linking_output_pool->{}".format(schema_linking_output_pool))
|
|
|
|
|
|
|
|
|
|
schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool]
|
|
|
|
|
@@ -310,19 +316,22 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
|
|
|
|
|
|
|
|
|
|
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
                             current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
                             fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
    """Build the SQL prompt pool and run every prompt through the LLM concurrently.

    Args:
        question, model_name, fields_list, schema_link_str_pool, current_date,
        prior_schema_links, prior_exts, fewshot_example_list_combo:
            Forwarded unchanged to ``self.generate_sql_prompt_pool``.
        llm_config: Passed to ``get_llm`` to obtain the LLM client for this
            call (its expected keys are not visible here -- confirm against
            ``get_llm``).

    Returns:
        A ``(sql_output_pool, sql_prompt_pool)`` tuple. ``asyncio.gather``
        returns results in the order of the awaitables, so outputs line up
        with the prompts they were produced from.
    """
    sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
    logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))

    # One client per call, shared by all prompts of this request.
    llm = get_llm(llm_config)
    sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
    logger.debug("sql_output_pool->{}".format(sql_output_pool))

    return sql_output_pool, sql_prompt_pool
|
|
|
|
|
|
|
|
|
|
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
|
|
|
|
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict):
|
|
|
|
|
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
|
|
|
|
schema_linking_sql_output_task_pool = [self.llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
|
|
|
|
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
|
|
|
|
|
logger.debug("schema_linking_sql_output_res_pool->{}".format(schema_linking_sql_output_res_pool))
|
|
|
|
|
|
|
|
|
|
@@ -330,7 +339,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
|
|
|
|
|
model_name: str, fields_list: List[str],
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("filter_condition: {}".format(filter_condition))
|
|
|
|
|
logger.info("model_name: {}".format(model_name))
|
|
|
|
|
@@ -342,14 +351,14 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
|
|
|
|
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
|
|
|
|
|
|
|
|
|
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
|
|
|
|
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
|
|
|
|
|
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
|
|
|
|
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
|
|
|
|
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
|
|
|
|
|
|
|
|
|
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
|
|
|
|
|
|
|
|
|
|
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
|
|
|
|
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
|
|
|
|
|
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
|
|
|
|
|
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
|
|
|
|
|
|
|
|
|
|
@@ -374,7 +383,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
return resp
|
|
|
|
|
|
|
|
|
|
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
|
|
|
|
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("filter_condition: {}".format(filter_condition))
|
|
|
|
|
logger.info("model_name: {}".format(model_name))
|
|
|
|
|
@@ -420,8 +429,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
             sql_example_prompter:FewShotPromptTemplate2) -> None:
    """Initialise the agent by delegating all settings to the base class.

    Args:
        num_fewshots: Forwarded to the base constructor.
        num_examples: Forwarded to the base constructor.
        num_self_consistency: Forwarded to the base constructor.
        sql_example_prompter: Forwarded to the base constructor.
    """
    super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter)
|
|
|
|
|
|
|
|
|
|
def reload_setting(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
|
|
|
|
|
self.num_fewshots = num_fewshots
|
|
|
|
|
@@ -554,12 +563,13 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
async def generate_schema_linking_tasks(self, question: str, domain_name: str,
|
|
|
|
|
fields_list: List[str], prior_schema_links: Mapping[str,str],
|
|
|
|
|
fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
|
|
|
|
fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
|
|
|
|
|
|
|
|
|
|
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(question, domain_name,
|
|
|
|
|
fields_list, prior_schema_links,
|
|
|
|
|
fewshot_example_list_combo)
|
|
|
|
|
schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_linking_output_task_pool = [llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
|
|
|
|
|
schema_linking_output_pool = await asyncio.gather(*schema_linking_output_task_pool)
|
|
|
|
|
logger.debug(f'schema_linking_output_pool:{schema_linking_output_pool}')
|
|
|
|
|
|
|
|
|
|
@@ -568,25 +578,29 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
return schema_linking_str_pool
|
|
|
|
|
|
|
|
|
|
async def generate_sql_tasks(self, question: str, domain_name: str, data_date: str,
                             schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]],
                             llm_config: dict):
    """Build the SQL prompt pool and run every prompt through the LLM concurrently.

    Args:
        question, domain_name, data_date, schema_link_str_pool,
        fewshot_example_list_combo:
            Forwarded to ``self.generate_sql_prompt_pool`` (note that the
            helper takes ``schema_link_str_pool`` before ``data_date``).
        llm_config: Passed to ``get_llm`` to obtain the LLM client for this
            call (its expected keys are not visible here -- confirm against
            ``get_llm``).

    Returns:
        The list of LLM outputs, in the same order as the generated prompts
        (``asyncio.gather`` preserves the order of its awaitables).
    """
    sql_prompt_pool = self.generate_sql_prompt_pool(question, domain_name, schema_link_str_pool, data_date, fewshot_example_list_combo)

    # Create the coroutines exactly once; un-awaited duplicates would only
    # trigger "coroutine was never awaited" warnings.
    llm = get_llm(llm_config)
    sql_output_task_pool = [llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
    sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
    logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')

    return sql_output_res_pool
|
|
|
|
|
|
|
|
|
|
async def generate_schema_linking_sql_tasks(self, question: str, domain_name: str, fields_list: List[str], data_date: str,
                                            prior_schema_links: Mapping[str,str], fewshot_example_list_combo:List[List[Mapping[str, str]]],
                                            llm_config: dict):
    """Run the combined schema-linking + SQL prompts through the LLM concurrently.

    Args:
        question, domain_name, fields_list, data_date, prior_schema_links,
        fewshot_example_list_combo:
            Forwarded unchanged to ``self.generate_schema_linking_sql_prompt_pool``.
        llm_config: Passed to ``get_llm`` to obtain the LLM client for this
            call (its expected keys are not visible here -- confirm against
            ``get_llm``).

    Returns:
        The list of raw LLM outputs, in the same order as the generated
        prompts (``asyncio.gather`` preserves the order of its awaitables).
    """
    schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, domain_name, fields_list, data_date, prior_schema_links, fewshot_example_list_combo)

    # Create the coroutines exactly once; un-awaited duplicates would only
    # trigger "coroutine was never awaited" warnings.
    llm = get_llm(llm_config)
    schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
    schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
    logger.debug(f'schema_linking_sql_output_res_pool:{schema_linking_sql_output_res_pool}')

    return schema_linking_sql_output_res_pool
|
|
|
|
|
|
|
|
|
|
async def tasks_run(self, question: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
|
|
|
|
|
async def tasks_run(self, question: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str, llm_config: dict):
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("domain_name: {}".format(domain_name))
|
|
|
|
|
logger.info("fields_list: {}".format(fields_list))
|
|
|
|
|
@@ -601,7 +615,7 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
|
|
|
|
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
|
|
|
|
|
|
|
|
|
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
|
|
|
|
|
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo, llm_config)
|
|
|
|
|
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
|
|
|
|
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
|
|
|
|
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
|
|
|
|
@@ -675,7 +689,7 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
|
|
|
|
model_name: str, fields_list: List[str],
|
|
|
|
|
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
|
|
|
|
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("model_name: {}".format(model_name))
|
|
|
|
|
logger.info("fields_list: {}".format(fields_list))
|
|
|
|
|
@@ -690,13 +704,14 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
|
|
|
|
schema_linking_prompt = self.generate_schema_linking_prompt(question, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
|
|
|
|
|
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
|
|
|
|
|
schema_link_output = await self.llm._call_async(schema_linking_prompt)
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_link_output = await llm._call_async(schema_linking_prompt)
|
|
|
|
|
|
|
|
|
|
schema_link_str = schema_link_parse(schema_link_output)
|
|
|
|
|
|
|
|
|
|
sql_prompt = self.generate_sql_prompt(question, model_name, schema_link_str, data_date, fewshot_example_meta_list)
|
|
|
|
|
logger.debug("sql_prompt->{}".format(sql_prompt))
|
|
|
|
|
sql_output = await self.llm._call_async(sql_prompt)
|
|
|
|
|
sql_output = await llm._call_async(sql_prompt)
|
|
|
|
|
|
|
|
|
|
resp = dict()
|
|
|
|
|
resp['question'] = question
|
|
|
|
|
@@ -716,7 +731,8 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
|
|
|
|
|
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
|
|
|
|
|
model_name: str, fields_list: List[str],
|
|
|
|
|
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
|
|
|
|
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
|
|
|
|
llm_config: dict):
|
|
|
|
|
|
|
|
|
|
logger.info("question: {}".format(question))
|
|
|
|
|
logger.info("model_name: {}".format(model_name))
|
|
|
|
|
@@ -732,7 +748,8 @@ class Text2DSLAgent(Text2DSLAgentBase):
|
|
|
|
|
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
|
|
|
|
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
|
|
|
|
|
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
|
|
|
|
|
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
|
|
|
|
|
llm = get_llm(llm_config)
|
|
|
|
|
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
|
|
|
|
|
|
|
|
|
|
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
|
|
|
|
|
sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)
|
|
|
|
|
@@ -764,26 +781,26 @@ class Text2DSLAgentWrapper(object):
|
|
|
|
|
|
|
|
|
|
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
                          model_name: str, fields_list: List[str],
                          data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str, llm_config: dict):
    """Dispatch a text-to-SQL request to the agent method for the given mode.

    Args:
        question, filter_condition, model_name, fields_list,
        prior_schema_links, prior_exts:
            Forwarded as keyword arguments to the selected
            ``self.sql_agent_act`` method.
        data_date: Forwarded to the agent under the keyword ``current_date``.
        sql_generation_mode: Must be the ``value`` of a ``SqlModeEnum``
            member and one of the four modes handled below.
        llm_config: Forwarded to the agent as ``llm_config``.

    Returns:
        Whatever the selected agent coroutine returns.

    Raises:
        ValueError: If ``sql_generation_mode`` is not a ``SqlModeEnum`` value
            or has no handler here.
    """
    if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
        raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")

    # Mode -> name of the agent coroutine that implements it. Resolved by
    # name so only the selected attribute is looked up on the agent.
    mode_to_method = {
        '1_pass_auto_cot': 'async_query2sql_shortcut',
        '1_pass_auto_cot_self_consistency': 'tasks_run_shortcut',
        '2_pass_auto_cot': 'async_query2sql',
        '2_pass_auto_cot_self_consistency': 'tasks_run',
    }
    method_name = mode_to_method.get(sql_generation_mode)
    if method_name is None:
        # Reached only when SqlModeEnum has a value without a handler here.
        raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')

    logger.info(f"sql wrapper: {sql_generation_mode}")
    agent_method = getattr(self.sql_agent_act, method_name)
    # Every mode takes the same keyword arguments; the agent is called once.
    resp = await agent_method(question=question, filter_condition=filter_condition,
                              model_name=model_name, fields_list=fields_list,
                              current_date=data_date, prior_schema_links=prior_schema_links,
                              prior_exts=prior_exts, llm_config=llm_config)
    return resp
|
|
|
|
|
|