diff --git a/chat/python/config/run_config.ini b/chat/python/config/run_config.ini index d1c0c143f..48f5ac32e 100644 --- a/chat/python/config/run_config.ini +++ b/chat/python/config/run_config.ini @@ -25,4 +25,4 @@ LLM_PROVIDER_NAME = openai [LLMModel] MODEL_NAME = gpt-3.5-turbo-16k OPENAI_API_KEY = YOUR_API_KEY -TEMPERATURE = 0.0 +TEMPERATURE = 0.0 \ No newline at end of file diff --git a/chat/python/few_shot_example/s2ql_examplar.py b/chat/python/few_shot_example/s2sql_exemplar.py similarity index 99% rename from chat/python/few_shot_example/s2ql_examplar.py rename to chat/python/few_shot_example/s2sql_exemplar.py index f48e7a634..92cceb150 100644 --- a/chat/python/few_shot_example/s2ql_examplar.py +++ b/chat/python/few_shot_example/s2sql_exemplar.py @@ -1,4 +1,4 @@ -examplars= [ +exemplars= [ { "currentDate":"2020-12-01", "tableName":"内容库产品", "fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", diff --git a/chat/python/few_shot_example/s2ql_examplar3_transformed.json b/chat/python/few_shot_example/s2sql_exemplar3_transformed.json similarity index 100% rename from chat/python/few_shot_example/s2ql_examplar3_transformed.json rename to chat/python/few_shot_example/s2sql_exemplar3_transformed.json diff --git a/chat/python/services/s2ql/auto_cot.py b/chat/python/services/s2sql/auto_cot.py similarity index 100% rename from chat/python/services/s2ql/auto_cot.py rename to chat/python/services/s2sql/auto_cot.py diff --git a/chat/python/services/s2ql/auto_cot_run.py b/chat/python/services/s2sql/auto_cot_run.py similarity index 100% rename from chat/python/services/s2ql/auto_cot_run.py rename to chat/python/services/s2sql/auto_cot_run.py diff --git a/chat/python/services/s2ql/constructor.py b/chat/python/services/s2sql/constructor.py similarity index 100% rename from chat/python/services/s2ql/constructor.py rename to chat/python/services/s2sql/constructor.py diff --git a/chat/python/services/s2ql/examples_reload_run.py b/chat/python/services/s2sql/examples_reload_run.py similarity index 91% rename from chat/python/services/s2ql/examples_reload_run.py rename to chat/python/services/s2sql/examples_reload_run.py index 9f4b80e69..97d13252d 100644 --- a/chat/python/services/s2ql/examples_reload_run.py +++ b/chat/python/services/s2sql/examples_reload_run.py @@ -15,7 +15,7 @@ from instances.logging_instance import logger from config.config_parse import ( TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM, LLMPARSER_HOST, LLMPARSER_PORT,) -from few_shot_example.s2ql_examplar import examplars as sql_examplars +from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str, @@ -35,6 +35,6 @@ def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str, if __name__ == "__main__": text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT, - sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM) + sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM) diff --git a/chat/python/services/s2ql/output_parser.py b/chat/python/services/s2sql/output_parser.py similarity index 100% rename from chat/python/services/s2ql/output_parser.py rename to chat/python/services/s2sql/output_parser.py diff --git a/chat/python/services/s2ql/run.py b/chat/python/services/s2sql/run.py similarity index 90% rename from chat/python/services/s2ql/run.py rename to chat/python/services/s2sql/run.py index 89bc548cc..03165dade 100644 --- a/chat/python/services/s2ql/run.py +++ b/chat/python/services/s2sql/run.py @@ -11,15 +11,15 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) import json -from s2ql.constructor import FewShotPromptTemplate2 -from s2ql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper +from s2sql.constructor import FewShotPromptTemplate2 +from s2sql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper from instances.llm_instance import llm from instances.chromadb_instance import client as chromadb_client from instances.logging_instance import logger from instances.text2vec_instance import emb_func -from few_shot_example.s2ql_examplar import examplars as sql_examplars +from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTACT_COLLECTION_NAME, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM, ACT_MIN_WINDOWN_SIZE, ACT_MAX_WINDOWN_SIZE) @@ -45,8 +45,8 @@ text2sql_agent_autoCoT = Text2DSLAgentAutoCoT(num_fewshots=TEXT2DSL_FEWSHOTS_NUM sql_example_prompter=text2dsl_agent_act_example_prompter, llm=llm, auto_cot_min_window_size=ACT_MIN_WINDOWN_SIZE, auto_cot_max_window_size=ACT_MAX_WINDOWN_SIZE) -sql_ids = [str(i) for i in range(0, len(sql_examplars))] -text2sql_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM) +sql_ids = [str(i) for i in range(0, len(sql_exemplars))] +text2sql_agent.reload_setting(sql_ids, sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM) if text2sql_agent_autoCoT.count_examples()==0: source_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/chat/python/services/s2ql/sql_agent.py b/chat/python/services/s2sql/sql_agent.py similarity index 96% rename from chat/python/services/s2ql/sql_agent.py rename to chat/python/services/s2sql/sql_agent.py index f3cd7f370..0284de098 100644 --- a/chat/python/services/s2ql/sql_agent.py +++ b/chat/python/services/s2sql/sql_agent.py @@ -14,9 +14,9 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) from instances.logging_instance import logger -from s2ql.constructor import FewShotPromptTemplate2 -from s2ql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse -from s2ql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run +from s2sql.constructor import FewShotPromptTemplate2 +from s2sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse +from s2sql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run class Text2DSLAgentBase(object): @@ -248,6 +248,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): resp['priorExts'] = prior_exts resp['currentDate'] = current_date + resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt] + resp['schemaLinkingOutput'] = schema_link_output resp['schemaLinkStr'] = schema_link_str @@ -285,6 +287,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): resp['priorExts'] = prior_exts resp['currentDate'] = current_date + resp['prompt'] = [schema_linking_sql_shortcut_prompt] + resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output resp['schemaLinkStr'] = schema_linking_str resp['sqlOutput'] = sql_str @@ -303,7 +307,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool] - return schema_linking_str_pool + return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]): @@ -313,7 +317,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): sql_output_pool = await asyncio.gather(*[self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]) logger.debug("sql_output_pool->{}".format(sql_output_pool)) - return sql_output_pool + return sql_output_pool, sql_prompt_pool async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]): @@ -322,7 +326,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool) logger.debug("schema_linking_sql_output_res_pool->{}".format(schema_linking_sql_output_res_pool)) - return schema_linking_sql_output_res_pool + return schema_linking_sql_output_res_pool, schema_linking_sql_prompt_pool, schema_linking_sql_output_task_pool async def tasks_run(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str], @@ -338,14 +342,14 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples) fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots) - schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) + schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}') schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list) logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}') schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list) - sql_output_candicates = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) + sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) logger.debug(f'sql_output_candicates:{sql_output_candicates}') sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates) @@ -357,6 +361,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): resp['priorExts'] = prior_exts resp['currentDate'] = current_date + resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt for schema_linking_prompt, sql_prompt in zip(schema_linking_prompt_list, sql_output_prompt_list)] + resp['schemaLinkStr'] = schema_linking_output_max resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage @@ -380,7 +386,7 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples) fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots) - schema_linking_sql_output_candidates = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) + schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo) logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}') schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates] logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}') @@ -400,6 +406,8 @@ class Text2DSLAgentAutoCoT(Text2DSLAgentBase): resp['priorExts'] = prior_exts resp['currentDate'] = current_date + resp['prompt'] = schema_linking_sql_prompt_list + resp['schemaLinkStr'] = schema_linking_output_max resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage diff --git a/chat/python/services_router/query2sql_service.py b/chat/python/services_router/query2sql_service.py index 76bada405..cef37c215 100644 --- a/chat/python/services_router/query2sql_service.py +++ b/chat/python/services_router/query2sql_service.py @@ -8,7 +8,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) from fastapi import APIRouter, Depends, HTTPException -from services.s2ql.run import text2sql_agent_router +from services.s2sql.run import text2sql_agent_router router = APIRouter()