mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
@@ -2,7 +2,9 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List, Union, Mapping
|
||||
from typing import Any, Dict, List, Union, Mapping
|
||||
|
||||
from git import Optional
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -14,7 +16,7 @@ from auto_cot import auto_cot_run
|
||||
|
||||
|
||||
|
||||
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None):
|
||||
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None, terms_list: Optional[List[Dict]] = []):
|
||||
db_schema = f"Table: {table_name}, Columns = {field_list}\nForeign_keys: []"
|
||||
|
||||
prior_linkings_pairs = []
|
||||
@@ -40,7 +42,27 @@ def transform_sql_example(question:str, current_date:str, table_name:str, field_
|
||||
|
||||
current_data_str = """当前的日期是{}""".format(current_date)
|
||||
|
||||
question_augmented = """{question} (补充信息:{prior_linking}。{current_date}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str)
|
||||
terms_desc = ''
|
||||
|
||||
if len(terms_list) > 0:
|
||||
terms_desc += "相关业务术语:"
|
||||
for idx, term in enumerate(terms_list):
|
||||
|
||||
if (term['description'] is not None and len(term['description']) > 0) and (term['alias'] is not None and len(term['alias']) > 0):
|
||||
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,它通常是指<{term['description']}>,类似的表达还有{term['alias']};"""
|
||||
elif (term['description'] is None or len(term['description']) == 0) and (term['alias'] is not None and len(term['alias']) > 0):
|
||||
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,类似的表达还有{term['alias']};"""
|
||||
elif (term['description'] is not None and len(term['description']) > 0) and (term['alias'] is None or len(term['alias']) == 0):
|
||||
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语,它通常是指<{term['description']}>;"""
|
||||
else:
|
||||
terms_desc += f"""{idx+1}.<{term['name']}>是业务术语;"""
|
||||
|
||||
if len(terms_desc) > 0:
|
||||
terms_desc = terms_desc[:-1]
|
||||
|
||||
|
||||
|
||||
question_augmented = """{question} (补充信息:{prior_linking}。{current_date}。{terms_desc}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str)
|
||||
|
||||
return question_augmented, db_schema, sql
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping, Any
|
||||
from typing import Dict, List, Optional, Union, Mapping, Any
|
||||
from collections import Counter
|
||||
import random
|
||||
import asyncio
|
||||
@@ -158,7 +158,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str],
|
||||
schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
|
||||
fewshot_example_list:List[Mapping[str, str]])-> str:
|
||||
fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = [])-> str:
|
||||
|
||||
instruction = "# Use the the schema links to generate the SQL queries for each of the questions."
|
||||
sql_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"]
|
||||
@@ -168,7 +168,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
example_keys=sql_example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
|
||||
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list=terms_list)
|
||||
new_case_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: "
|
||||
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
|
||||
|
||||
@@ -179,26 +179,27 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str],
|
||||
schema_link_str_pool: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
|
||||
fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
|
||||
fewshot_example_list_pool:List[List[Mapping[str, str]]], terms_list: Optional[List[Dict]] = [])-> List[str]:
|
||||
sql_prompt_pool = []
|
||||
for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
|
||||
sql_prompt = self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_list)
|
||||
sql_prompt = self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_list, terms_list=terms_list)
|
||||
sql_prompt_pool.append(sql_prompt)
|
||||
|
||||
return sql_prompt_pool
|
||||
|
||||
def generate_schema_linking_sql_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
|
||||
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]]):
|
||||
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]], terms_list: Optional[List[Dict]] = []):
|
||||
|
||||
instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions."
|
||||
|
||||
example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT","sql"]
|
||||
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}"
|
||||
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}\n"
|
||||
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
|
||||
example_keys=example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
|
||||
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
|
||||
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, terms_list)
|
||||
new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Let’s think step by step. In the question "{questionAugmented2}", we are asked:"""
|
||||
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented)
|
||||
|
||||
@@ -208,10 +209,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
return prompt
|
||||
|
||||
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
|
||||
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
|
||||
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]], terms_list: Optional[List[Dict]] = [])-> List[str]:
|
||||
schema_linking_sql_prompt_pool = []
|
||||
for fewshot_example_list in fewshot_example_list_pool:
|
||||
schema_linking_sql_prompt = self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list)
|
||||
schema_linking_sql_prompt = self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list, terms_list=terms_list)
|
||||
schema_linking_sql_prompt_pool.append(schema_linking_sql_prompt)
|
||||
|
||||
return schema_linking_sql_prompt_pool
|
||||
@@ -219,7 +220,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
llm_config:dict):
|
||||
llm_config:dict, terms_list: Optional[List[Dict]] = []):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -227,6 +228,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
logger.info("terms_list: {}".format(terms_list))
|
||||
|
||||
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
||||
@@ -238,7 +241,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
schema_link_str = schema_link_parse(schema_link_output)
|
||||
logger.debug("schema_link_str->{}".format(schema_link_str))
|
||||
|
||||
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
||||
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list=terms_list)
|
||||
logger.debug("sql_prompt->{}".format(sql_prompt))
|
||||
sql_output = await llm._call_async(sql_prompt)
|
||||
|
||||
@@ -264,7 +267,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
llm_config:dict):
|
||||
llm_config:dict, terms_list: Optional[List[Dict]] = []):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -272,9 +275,10 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
logger.info("terms_list: {}".format(terms_list))
|
||||
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
|
||||
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list, terms_list)
|
||||
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_sql_shortcut_output = await llm._call_async(schema_linking_sql_shortcut_prompt)
|
||||
@@ -316,9 +320,9 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
|
||||
|
||||
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]], llm_config: dict, terms_list: Optional[List[Dict]] = []):
|
||||
|
||||
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
|
||||
logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))
|
||||
llm = get_llm(llm_config)
|
||||
sql_output_pool = await asyncio.gather(*[llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
|
||||
@@ -328,8 +332,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str,
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict):
|
||||
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]],llm_config: dict, terms_list: Optional[List[Dict]] = []):
|
||||
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
|
||||
llm = get_llm(llm_config)
|
||||
schema_linking_sql_output_task_pool = [llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
|
||||
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
|
||||
@@ -339,7 +343,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
|
||||
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -347,18 +351,20 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
logger.info("terms_list: {}".format(terms_list))
|
||||
|
||||
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
||||
|
||||
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
|
||||
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config,)
|
||||
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
||||
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
||||
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
||||
|
||||
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
|
||||
|
||||
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config)
|
||||
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, llm_config, terms_list=terms_list)
|
||||
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
|
||||
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
|
||||
|
||||
@@ -383,7 +389,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
return resp
|
||||
|
||||
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict):
|
||||
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
|
||||
logger.info("question: {}".format(question))
|
||||
logger.info("filter_condition: {}".format(filter_condition))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
@@ -391,11 +397,12 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
logger.info("terms_list: {}".format(terms_list))
|
||||
|
||||
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
|
||||
|
||||
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
|
||||
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo, terms_list=terms_list)
|
||||
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
|
||||
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
|
||||
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
|
||||
@@ -781,26 +788,26 @@ class Text2DSLAgentWrapper(object):
|
||||
|
||||
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str, llm_config: dict):
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str, llm_config: dict, terms_list: Optional[List[Dict]] = []):
|
||||
|
||||
if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
|
||||
raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")
|
||||
|
||||
if sql_generation_mode == '1_pass_auto_cot':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
|
||||
return resp
|
||||
elif sql_generation_mode == '1_pass_auto_cot_self_consistency':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
|
||||
return resp
|
||||
elif sql_generation_mode == '2_pass_auto_cot':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
|
||||
return resp
|
||||
elif sql_generation_mode == '2_pass_auto_cot_self_consistency':
|
||||
logger.info(f"sql wrapper: {sql_generation_mode}")
|
||||
resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config)
|
||||
resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts, llm_config=llm_config, terms_list=terms_list)
|
||||
return resp
|
||||
else:
|
||||
raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')
|
||||
|
||||
@@ -59,6 +59,7 @@ async def query2sql(query_body: Mapping[str, Any]):
|
||||
dataset_name = schema['dataSetName']
|
||||
fields_list = schema['fieldNameList']
|
||||
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
||||
terms_list = schema['terms']
|
||||
|
||||
resp = await text2sql_agent_router.async_query2sql(question=query_text, filter_condition=filter_condition,
|
||||
model_name=dataset_name, fields_list=fields_list,
|
||||
|
||||
Reference in New Issue
Block a user