(fix)(chat):fix typo in s2sql and add prompt to output. (#581)

* 1. refactor the retrieval module. 2. refactor the http service module. 3. upgrade the text2sql output format and the parsing of absolute-time-related expressions in queries.

* fix bug.

* upgrade the config module; it now supports configuring LLMs supported by langchain.

* fix conflicts.

* update text2sql config reload to be compatible with the new config format.

* modify default config.

* 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantical retriever modules.

* merge with upstream master

* add general retrieve service.

* add api service for sql_agent for crud operations on few-shot examples.

* modify requirements

* add auto-cot feature

* 1. output log to a fixed log file. 2. allow few-shot examples to be tied to a data model, and add a strategy that extends the examples when the retrieved examples tied to a data model are not enough. 3. fix misformat in s2ql args.

* add prior_ext to output.

* fix typo in s2sql

* add prompt to output.

---------

Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
codescracker
2024-01-02 16:35:33 +08:00
committed by GitHub
parent 49f0a4dc1d
commit af1c560cc4
11 changed files with 27 additions and 19 deletions

View File

@@ -25,4 +25,4 @@ LLM_PROVIDER_NAME = openai
[LLMModel]
MODEL_NAME = gpt-3.5-turbo-16k
OPENAI_API_KEY = YOUR_API_KEY
TEMPERATURE = 0.0
TEMPERATURE = 0.0

View File

@@ -1,4 +1,4 @@
examplars= [
exemplars= [
{ "currentDate":"2020-12-01",
"tableName":"内容库产品",
"fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",

View File

@@ -15,7 +15,7 @@ from instances.logging_instance import logger
from config.config_parse import (
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
LLMPARSER_HOST, LLMPARSER_PORT,)
from few_shot_example.s2ql_examplar import examplars as sql_examplars
from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
@@ -35,6 +35,6 @@ def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
if __name__ == "__main__":
text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT,
sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)

View File

@@ -11,15 +11,15 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import json
from s2ql.constructor import FewShotPromptTemplate2
from s2ql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
from s2sql.constructor import FewShotPromptTemplate2
from s2sql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
from instances.llm_instance import llm
from instances.chromadb_instance import client as chromadb_client
from instances.logging_instance import logger
from instances.text2vec_instance import emb_func
from few_shot_example.s2ql_examplar import examplars as sql_examplars
from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTACT_COLLECTION_NAME,
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
ACT_MIN_WINDOWN_SIZE, ACT_MAX_WINDOWN_SIZE)
@@ -45,8 +45,8 @@ text2sql_agent_autoCoT = Text2DSLAgentAutoCoT(num_fewshots=TEXT2DSL_FEWSHOTS_NUM
sql_example_prompter=text2dsl_agent_act_example_prompter, llm=llm,
auto_cot_min_window_size=ACT_MIN_WINDOWN_SIZE, auto_cot_max_window_size=ACT_MAX_WINDOWN_SIZE)
sql_ids = [str(i) for i in range(0, len(sql_examplars))]
text2sql_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
sql_ids = [str(i) for i in range(0, len(sql_exemplars))]
text2sql_agent.reload_setting(sql_ids, sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
if text2sql_agent_autoCoT.count_examples()==0:
source_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

View File

@@ -14,9 +14,9 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from s2ql.constructor import FewShotPromptTemplate2
from s2ql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
from s2ql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
from s2sql.constructor import FewShotPromptTemplate2
from s2sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
from s2sql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
class Text2DSLAgentBase(object):
@@ -248,6 +248,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt]
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
@@ -285,6 +287,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = [schema_linking_sql_shortcut_prompt]
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
@@ -303,7 +307,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool]
return schema_linking_str_pool
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
@@ -313,7 +317,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
sql_output_pool = await asyncio.gather(*[self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
logger.debug("sql_output_pool->{}".format(sql_output_pool))
return sql_output_pool
return sql_output_pool, sql_prompt_pool
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
@@ -322,7 +326,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
logger.debug("schema_linking_sql_output_res_pool->{}".format(schema_linking_sql_output_res_pool))
return schema_linking_sql_output_res_pool
return schema_linking_sql_output_res_pool, schema_linking_sql_prompt_pool, schema_linking_sql_output_task_pool
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
@@ -338,14 +342,14 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
@@ -357,6 +361,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt for schema_linking_prompt, sql_prompt in zip(schema_linking_prompt_list, sql_output_prompt_list)]
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
@@ -380,7 +386,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_sql_output_candidates = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
@@ -400,6 +406,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = schema_linking_sql_prompt_list
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage

View File

@@ -8,7 +8,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from fastapi import APIRouter, Depends, HTTPException
from services.s2ql.run import text2sql_agent_router
from services.s2sql.run import text2sql_agent_router
router = APIRouter()