diff --git a/chat/core/src/main/python/run_config.py b/chat/core/src/main/python/run_config.py index e0fe5f1aa..2d4cbaf53 100644 --- a/chat/core/src/main/python/run_config.py +++ b/chat/core/src/main/python/run_config.py @@ -15,6 +15,7 @@ CHROMA_DB_PERSIST_DIR = 'chm_db' PRESET_QUERY_COLLECTION_NAME = "preset_query_collection" TEXT2DSL_COLLECTION_NAME = "text2dsl_collection" TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15 +TEXT2DSL_IS_SHORTCUT = False CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR) diff --git a/chat/core/src/main/python/sql/constructor.py b/chat/core/src/main/python/sql/constructor.py index b844a84fe..2553e4eca 100644 --- a/chat/core/src/main/python/sql/constructor.py +++ b/chat/core/src/main/python/sql/constructor.py @@ -22,10 +22,8 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM - def reload_sql_example_collection(vectorstore:Chroma, sql_examplars:List[Mapping[str, str]], - schema_linking_example_selector:SemanticSimilarityExampleSelector, sql_example_selector:SemanticSimilarityExampleSelector, example_nums:int ): @@ -35,20 +33,16 @@ def reload_sql_example_collection(vectorstore:Chroma, print("emptied sql_examples_collection size:", vectorstore._collection.count()) - schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]) - sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["question", "current_date", "table_name", "schema_links", "sql"]) + input_keys=["question"], + example_keys=["table_name", "fields_list", "prior_schema_links", "question", 
"analysis", "schema_links", "current_date", "sql"]) for example in sql_examplars: - schema_linking_example_selector.add_example(example) + sql_example_selector.add_example(example) print("reloaded sql_examples_collection size:", vectorstore._collection.count()) - return vectorstore, schema_linking_example_selector, sql_example_selector + return vectorstore, sql_example_selector sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME, @@ -57,22 +51,14 @@ sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME, example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM -schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]) - sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["question", "current_date", "table_name", "schema_links", "sql"]) + input_keys=["question"], + example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"]) if sql_examples_vectorstore._collection.count() > 0: print("examples already in sql_vectorstore") print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count()) - if sql_examples_vectorstore._collection.count() < len(sql_examplars): - print("sql_examplars size:", len(sql_examplars)) - sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums) - print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) -else: - sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, 
sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums) - print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) +print("sql_examplars size:", len(sql_examplars)) +sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums) +print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) diff --git a/chat/core/src/main/python/sql/examples_reload_run.py b/chat/core/src/main/python/sql/examples_reload_run.py index 65f1e3bed..65df9087d 100644 --- a/chat/core/src/main/python/sql/examples_reload_run.py +++ b/chat/core/src/main/python/sql/examples_reload_run.py @@ -8,24 +8,22 @@ import json sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM +from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT from few_shot_example.sql_exampler import examplars as sql_examplars -from run_config import LLMPARSER_HOST -from run_config import LLMPARSER_PORT +from run_config import LLMPARSER_HOST, LLMPARSER_PORT def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str, - sql_examplars:List[Mapping[str, str]], example_nums:int): + sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool): url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/" print("url: ", url) - payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums} + payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut} headers = {'content-type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers) print(response.text) if __name__ == "__main__": - arguments = sys.argv text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT, - sql_examplars, 
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM) + sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT) diff --git a/chat/core/src/main/python/sql/output_parser.py b/chat/core/src/main/python/sql/output_parser.py index c90388850..aa0ff317f 100644 --- a/chat/core/src/main/python/sql/output_parser.py +++ b/chat/core/src/main/python/sql/output_parser.py @@ -10,4 +10,36 @@ def schema_link_parse(schema_link_output): print(e) schema_link_output = None - return schema_link_output \ No newline at end of file + return schema_link_output + +def combo_schema_link_parse(schema_linking_sql_combo_output: str): + try: + schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip() + pattern = r'Schema_links:(\[.*?\])' + schema_links_match = re.search(pattern, schema_linking_sql_combo_output) + + if schema_links_match: + schema_links = schema_links_match.group(1) + else: + schema_links = None + except Exception as e: + print(e) + schema_links = None + + return schema_links + +def combo_sql_parse(schema_linking_sql_combo_output: str): + try: + schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip() + pattern = r'SQL:(.*)' + sql_match = re.search(pattern, schema_linking_sql_combo_output) + + if sql_match: + sql = sql_match.group(1) + else: + sql = None + except Exception as e: + print(e) + sql = None + + return sql diff --git a/chat/core/src/main/python/sql/prompt_maker.py b/chat/core/src/main/python/sql/prompt_maker.py index 0cfed83b1..7c4f5fccc 100644 --- a/chat/core/src/main/python/sql/prompt_maker.py +++ b/chat/core/src/main/python/sql/prompt_maker.py @@ -73,3 +73,38 @@ def sql_exampler(user_query: str, schema_links=schema_link_str) return sql_example_prompt + + +def schema_linking_sql_combo_examplar(user_query: str, + domain_name: str, + data_date : str, + fields_list: List[str], + prior_schema_links: Mapping[str,str], + example_selector: SemanticSimilarityExampleSelector) -> str: + + prior_schema_links_str = '['+ 
','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' + + example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"], + template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}") + + instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句" + + schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。" + + schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate( + example_selector=example_selector, + example_prompt=example_prompt_template, + example_separator="\n\n", + prefix=instruction, + input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"], + suffix=schema_linking_sql_combo_prompt + ) + + schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name, + fields_list=fields_list, + prior_schema_links=prior_schema_links_str, + current_date=data_date, + question=user_query) + return schema_linking_sql_combo_example_prompt + + diff --git a/chat/core/src/main/python/sql/run.py b/chat/core/src/main/python/sql/run.py index a7ece82d8..02931b5c8 100644 --- a/chat/core/src/main/python/sql/run.py +++ b/chat/core/src/main/python/sql/run.py @@ -7,32 +7,37 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from sql.prompt_maker import schema_linking_exampler, sql_exampler -from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection 
-from sql.output_parser import schema_link_parse +from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar +from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection +from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse from util.llm_instance import llm - +from run_config import TEXT2DSL_IS_SHORTCUT class Text2DSLAgent(object): def __init__(self): self.schema_linking_exampler = schema_linking_exampler self.sql_exampler = sql_exampler + self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar + self.sql_examples_vectorstore = sql_examples_vectorstore - self.schema_linking_example_selector = schema_linking_example_selector self.sql_example_selector = sql_example_selector self.schema_link_parse = schema_link_parse + self.combo_schema_link_parse = combo_schema_link_parse + self.combo_sql_parse = combo_sql_parse self.llm = llm - def update_examples(self, sql_examplars, example_nums): - self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore, - sql_examplars, - self.schema_linking_example_selector, - self.sql_example_selector, - example_nums) + self.is_shortcut = TEXT2DSL_IS_SHORTCUT + + def update_examples(self, sql_examples, example_nums, is_shortcut): + self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore, + sql_examples, + self.sql_example_selector, + example_nums) + self.is_shortcut = is_shortcut def query2sql(self, query_text: str, schema : Union[dict, None] = None, @@ -53,14 +58,14 @@ class Text2DSLAgent(object): model_name = schema['modelName'] fields_list = schema['fieldNameList'] - schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector) + schema_linking_prompt = 
self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector) print("schema_linking_prompt->", schema_linking_prompt) schema_link_output = self.llm(schema_linking_prompt) schema_link_str = self.schema_link_parse(schema_link_output) sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector) print("sql_prompt->", sql_prompt) - sql_output = llm(sql_prompt) + sql_output = self.llm(sql_prompt) resp = dict() resp['query'] = query_text @@ -69,7 +74,7 @@ class Text2DSLAgent(object): resp['priorSchemaLinking'] = linking resp['dataDate'] = current_date - resp['schemaLinkingOutput'] = schema_link_output + resp['analysisOutput'] = schema_link_output resp['schemaLinkStr'] = schema_link_str resp['sqlOutput'] = sql_output @@ -78,5 +83,57 @@ class Text2DSLAgent(object): return resp + def query2sqlcombo(self, query_text: str, + schema : Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None + ): + + print("query_text: ", query_text) + print("schema: ", schema) + print("current_date: ", current_date) + print("prior_schema_links: ", linking) + + if linking is not None: + prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking} + else: + prior_schema_links = {} + + model_name = schema['modelName'] + fields_list = schema['fieldNameList'] + + schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list, + prior_schema_links, self.sql_example_selector) + print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt) + schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) + + schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output) + sql_str = self.combo_sql_parse(schema_linking_sql_combo_output) + + resp = dict() + resp['query'] = query_text + resp['model'] = model_name + resp['fields'] = 
fields_list + resp['priorSchemaLinking'] = prior_schema_links + resp['dataDate'] = current_date + + resp['analysisOutput'] = schema_linking_sql_combo_output + resp['schemaLinkStr'] = schema_linking_str + resp['sqlOutput'] = sql_str + + print("resp: ", resp) + + return resp + + def query2sql_run(self, query_text: str, + schema : Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None): + + if self.is_shortcut: + return self.query2sqlcombo(query_text, schema, current_date, linking) + else: + return self.query2sql(query_text, schema, current_date, linking) + text2sql_agent = Text2DSLAgent() diff --git a/chat/core/src/main/python/supersonic_llmparser.py b/chat/core/src/main/python/supersonic_llmparser.py index 963328a27..40ebfe613 100644 --- a/chat/core/src/main/python/supersonic_llmparser.py +++ b/chat/core/src/main/python/supersonic_llmparser.py @@ -51,7 +51,7 @@ async def din_query2sql(query_body: Mapping[str, Any]): else: linking = query_body['linking'] - resp = text2sql_agent.query2sql(query_text=query_text, + resp = text2sql_agent.query2sql_run(query_text=query_text, schema=schema, current_date=current_date, linking=linking) return resp @@ -70,7 +70,12 @@ async def query2sql_setting_update(query_body: Mapping[str, Any]): else: example_nums = query_body['exampleNums'] - text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums) + if 'isShortcut' not in query_body: + raise HTTPException(status_code=400, detail="isShortcut is not in query_body") + else: + is_shortcut = query_body['isShortcut'] + + text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut) return "success" diff --git a/docs/images/text2sql_config.png b/docs/images/text2sql_config.png index af552aca3..d9f641438 100644 Binary files a/docs/images/text2sql_config.png and b/docs/images/text2sql_config.png differ diff --git a/docs/userguides/text2sql_cn.md 
b/docs/userguides/text2sql_cn.md index 71c1efacc..eb207271d 100644 --- a/docs/userguides/text2sql_cn.md +++ b/docs/userguides/text2sql_cn.md @@ -5,21 +5,25 @@ text2sql的功能实现,高度依赖对LLM的应用。通过LLM生成SQL的过 ### **配置方式** 1. 样本池的配置。 - - supersonic/chat/core/src/main/python/llm/few_shot_example/sql_exampler.py为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。 + - supersonic/chat/core/src/main/python/few_shot_example/sql_exampler.py 为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。 2. 样本数量的配置。 - - 在supersonic/chat/core/src/main/python/llm/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。 + - 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。 - 默认值为15,为项目在内部实践后较优的经验值。样本少太少,对导致LLM在生成SQL的过程中缺少引导和示范,生成的SQL会更不稳定;样本太多,会增加生成SQL需要的时间和LLM的token消耗(或超过LLM的token上限)。 - -
+3. SQL生成方式的配置 + - 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_IS_SHORTCUT 变量进行配置。 + - 默认值为False;当为False时,会调用2次LLM生成SQL;当为True时,只会调用1次LLM生成SQL。相较于调用2次LLM的方式,只调用1次LLM时耗时会减少30-40%,token的消耗量会减少30%左右,但生成的SQL正确率会有所下降。 +
-

图1-1 样本数量的配置文件

+

图1-1 配置文件

-3. 运行中更新配置的脚本。 - 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过脚本 supersonic/chat/core/src/main/python/bin/text2sql_resetting.sh 在项目运行中让配置生效。 - +### **运行中更新配置的脚本** +1. 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过以下2种方式让配置在项目运行中生效。 + - 执行 supersonic-daemon.sh reload llmparser + - 执行 python examples_reload_run.py ### **FAQ** 1. 生成一个SQL需要消耗的的LLM token数量太多了,按照openAI对token的收费标准,生成一个SQL太贵了,可以少用一些token吗? - - 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。 + - 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。如配置方式3.中所示,只调用1次LLM生成SQL。 - 需要注意,样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。