From 438e8463f5f3138ba6df28831f725bc3d528e934 Mon Sep 17 00:00:00 2001 From: codescracker Date: Tue, 31 Oct 2023 20:02:20 +0800 Subject: [PATCH] add self-consistency feature for text2sql (#303) --- .../src/main/python/config/config_parse.py | 43 +- .../src/main/python/config/run_config.ini | 8 +- .../python/few_shot_example/sql_exampler.py | 3 +- .../python/instances/chromadb_instance.py | 21 + .../{util => instances}/llm_instance.py | 0 .../main/python/instances/logging_instance.py | 6 + .../python/{util => instances}/text2vec.py | 0 .../services/plugin_call/prompt_construct.py | 2 +- .../main/python/services/plugin_call/run.py | 2 +- .../preset_retrieval/preset_query_db.py | 111 ----- .../python/services/preset_retrieval/run.py | 51 --- .../services/query_retrieval/retriever.py | 98 +++++ .../python/services/query_retrieval/run.py | 74 +--- .../main/python/services/sql/constructor.py | 107 ++--- .../services/sql/examples_reload_run.py | 63 +-- .../main/python/services/sql/output_parser.py | 6 +- .../main/python/services/sql/prompt_maker.py | 166 -------- chat/core/src/main/python/services/sql/run.py | 208 ++-------- .../src/main/python/services/sql/sql_agent.py | 380 ++++++++++++++++++ .../services_router/query2sql_service.py | 74 ++-- .../src/main/python/util/logging_utils.py | 4 - .../chromadb_utils.py} | 64 +-- 22 files changed, 764 insertions(+), 727 deletions(-) create mode 100644 chat/core/src/main/python/instances/chromadb_instance.py rename chat/core/src/main/python/{util => instances}/llm_instance.py (100%) create mode 100644 chat/core/src/main/python/instances/logging_instance.py rename chat/core/src/main/python/{util => instances}/text2vec.py (100%) delete mode 100644 chat/core/src/main/python/services/preset_retrieval/preset_query_db.py delete mode 100644 chat/core/src/main/python/services/preset_retrieval/run.py create mode 100644 chat/core/src/main/python/services/query_retrieval/retriever.py delete mode 100644 
chat/core/src/main/python/services/sql/prompt_maker.py create mode 100644 chat/core/src/main/python/services/sql/sql_agent.py delete mode 100644 chat/core/src/main/python/util/logging_utils.py rename chat/core/src/main/python/{util/chromadb_instance.py => utils/chromadb_utils.py} (70%) diff --git a/chat/core/src/main/python/config/config_parse.py b/chat/core/src/main/python/config/config_parse.py index f8425347d..5ddba66c9 100644 --- a/chat/core/src/main/python/config/config_parse.py +++ b/chat/core/src/main/python/config/config_parse.py @@ -1,7 +1,14 @@ # -*- coding:utf-8 -*- import os import configparser -from util.logging_utils import logger + +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from instances.logging_instance import logger def type_convert(input_str: str): @@ -11,10 +18,12 @@ def type_convert(input_str: str): return input_str -PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__)) - +PROJECT_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +config_dir = "config" +CONFIG_DIR_PATH = os.path.join(PROJECT_DIR_PATH, config_dir) config_file = "run_config.ini" -config_path = os.path.join(PROJECT_DIR_PATH, config_file) +config_path = os.path.join(CONFIG_DIR_PATH, config_file) + config = configparser.ConfigParser() config.read(config_path) @@ -26,9 +35,13 @@ chroma_db_section_name = "ChromaDB" CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR') PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME') SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME') -TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME') -TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 
int(config.get(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM')) +TEXT2DSLAGENT_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENT_COLLECTION_NAME') +TEXT2DSLAGENTCS_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENTCS_COLLECTION_NAME') +TEXT2DSL_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_EXAMPLE_NUM')) +TEXT2DSL_FEWSHOTS_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEWSHOTS_NUM')) +TEXT2DSL_SELF_CONSISTENCY_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_SELF_CONSISTENCY_NUM')) TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT')) +TEXT2DSL_IS_SELF_CONSISTENCY = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SELF_CONSISTENCY')) CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR) text2vec_section_name = "Text2Vec" @@ -44,10 +57,14 @@ for option in config.options(llm_model_section_name): if __name__ == "__main__": - logger.info("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH) - logger.info("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME) - logger.info("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH) - logger.info("LLMPARSER_HOST: ", LLMPARSER_HOST) - logger.info("LLMPARSER_PORT: ", LLMPARSER_PORT) - logger.info("llm_config_dict: ", llm_config_dict) - logger.info("is_shortcut: ", TEXT2DSL_IS_SHORTCUT) + logger.info(f"PROJECT_DIR_PATH: {PROJECT_DIR_PATH}") + logger.info(f"EMB_MODEL_PATH: {HF_TEXT2VEC_MODEL_NAME}") + logger.info(f"CHROMA_DB_PERSIST_PATH: {CHROMA_DB_PERSIST_PATH}") + logger.info(f"LLMPARSER_HOST: {LLMPARSER_HOST}") + logger.info(f"LLMPARSER_PORT: {LLMPARSER_PORT}") + logger.info(f"llm_config_dict: {llm_config_dict}") + logger.info(f"TEXT2DSL_EXAMPLE_NUM: {TEXT2DSL_EXAMPLE_NUM}") + logger.info(f"TEXT2DSL_FEWSHOTS_NUM: {TEXT2DSL_FEWSHOTS_NUM}") + logger.info(f"TEXT2DSL_SELF_CONSISTENCY_NUM: {TEXT2DSL_SELF_CONSISTENCY_NUM}") + logger.info(f"TEXT2DSL_IS_SHORTCUT: {TEXT2DSL_IS_SHORTCUT}") + 
logger.info(f"TEXT2DSL_IS_SELF_CONSISTENCY: {TEXT2DSL_IS_SELF_CONSISTENCY}") diff --git a/chat/core/src/main/python/config/run_config.ini b/chat/core/src/main/python/config/run_config.ini index 35b0b5375..5b596bcb1 100644 --- a/chat/core/src/main/python/config/run_config.ini +++ b/chat/core/src/main/python/config/run_config.ini @@ -6,9 +6,13 @@ LLMPARSER_PORT = 9092 CHROMA_DB_PERSIST_DIR = chm_db PRESET_QUERY_COLLECTION_NAME = preset_query_collection SOLVED_QUERY_COLLECTION_NAME = solved_query_collection -TEXT2DSL_COLLECTION_NAME = text2dsl_collection -TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15 +TEXT2DSLAGENT_COLLECTION_NAME = text2dsl_agent_collection +TEXT2DSLAGENTCS_COLLECTION_NAME = text2dsl_agent_cs_collection +TEXT2DSL_EXAMPLE_NUM = 15 +TEXT2DSL_FEWSHOTS_NUM = 10 +TEXT2DSL_SELF_CONSISTENCY_NUM = 5 TEXT2DSL_IS_SHORTCUT = False +TEXT2DSL_IS_SELF_CONSISTENCY = False [Text2Vec] HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese diff --git a/chat/core/src/main/python/few_shot_example/sql_exampler.py b/chat/core/src/main/python/few_shot_example/sql_exampler.py index aeeb4c79f..1c2fd669f 100644 --- a/chat/core/src/main/python/few_shot_example/sql_exampler.py +++ b/chat/core/src/main/python/few_shot_example/sql_exampler.py @@ -228,10 +228,11 @@ examplars= [ "question":"邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?", "prior_schema_links":"""['2312311'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?“中,我们被问: +“歌曲中,有哪些”,所以我们需要column=[歌曲名] “播放量大于500W的”,所以我们需要column=[结算播放量], cell values = [5000000],所以有[结算播放量:(5000000)] ”邓梓琦在2023年1月5日之后发布的“,所以我们需要column=[发布时间], cell values = ['2023-01-05'],所以有[发布时间:('2023-01-05')] ”邓梓琦“,所以我们需要column=[歌手名], cell values = ['邓梓琦'],所以有[歌手名:('邓梓琦')]""", - "schema_links":"""["结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""", + "schema_links":"""["歌曲名", "结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""", "sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '2023-01-05' and 歌手名 = '邓梓琦' and 结算播放量 > 
5000000""" }, { "current_date":"2023-09-17", diff --git a/chat/core/src/main/python/instances/chromadb_instance.py b/chat/core/src/main/python/instances/chromadb_instance.py new file mode 100644 index 000000000..d7a80e6e2 --- /dev/null +++ b/chat/core/src/main/python/instances/chromadb_instance.py @@ -0,0 +1,21 @@ +# -*- coding:utf-8 -*- +from typing import Any, List, Mapping, Optional, Union + +import chromadb +from chromadb.api import Collection +from chromadb.config import Settings + +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from config.config_parse import CHROMA_DB_PERSIST_PATH + + +client = chromadb.Client( + Settings( + chroma_db_impl="duckdb+parquet", + persist_directory=CHROMA_DB_PERSIST_PATH, # Optional, defaults to .chromadb/ in the current directory + ) +) \ No newline at end of file diff --git a/chat/core/src/main/python/util/llm_instance.py b/chat/core/src/main/python/instances/llm_instance.py similarity index 100% rename from chat/core/src/main/python/util/llm_instance.py rename to chat/core/src/main/python/instances/llm_instance.py diff --git a/chat/core/src/main/python/instances/logging_instance.py b/chat/core/src/main/python/instances/logging_instance.py new file mode 100644 index 000000000..5fd4fdfe8 --- /dev/null +++ b/chat/core/src/main/python/instances/logging_instance.py @@ -0,0 +1,6 @@ +from loguru import logger +import sys + +logger.remove() #remove the old handler. 
Else, the old one will work along with the new one you've added below' +logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", level="INFO") + diff --git a/chat/core/src/main/python/util/text2vec.py b/chat/core/src/main/python/instances/text2vec.py similarity index 100% rename from chat/core/src/main/python/util/text2vec.py rename to chat/core/src/main/python/instances/text2vec.py diff --git a/chat/core/src/main/python/services/plugin_call/prompt_construct.py b/chat/core/src/main/python/services/plugin_call/prompt_construct.py index 64b243578..9c3832e9f 100644 --- a/chat/core/src/main/python/services/plugin_call/prompt_construct.py +++ b/chat/core/src/main/python/services/plugin_call/prompt_construct.py @@ -5,7 +5,7 @@ import re import sys from typing import Any, List, Mapping, Union -from util.logging_utils import logger +from instances.logging_instance import logger sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) diff --git a/chat/core/src/main/python/services/plugin_call/run.py b/chat/core/src/main/python/services/plugin_call/run.py index 7021e7bb5..c58fb5653 100644 --- a/chat/core/src/main/python/services/plugin_call/run.py +++ b/chat/core/src/main/python/services/plugin_call/run.py @@ -12,7 +12,7 @@ from plugin_call.prompt_construct import ( construct_task_prompt, plugin_selection_output_parse, ) -from util.llm_instance import llm +from instances.llm_instance import llm def plugin_selection_run( diff --git a/chat/core/src/main/python/services/preset_retrieval/preset_query_db.py b/chat/core/src/main/python/services/preset_retrieval/preset_query_db.py deleted file mode 100644 index 70abcb5c1..000000000 --- a/chat/core/src/main/python/services/preset_retrieval/preset_query_db.py +++ /dev/null @@ -1,111 +0,0 @@ -# -*- coding:utf-8 -*- -import os -import sys -import uuid -from typing import Any, List, Mapping - -from chromadb.api import Collection 
- -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - -def get_ids(documents: List[str]) -> List[str]: - ids = [] - for doc in documents: - ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, doc))) - - return ids - - -def add2preset_query_collection( - collection: Collection, preset_queries: List[str], preset_query_ids: List[str] -) -> None: - - collection.add(documents=preset_queries, ids=preset_query_ids) - - -def update_preset_query_collection( - collection: Collection, preset_queries: List[str], preset_query_ids: List[str] -) -> None: - - collection.update(documents=preset_queries, ids=preset_query_ids) - - -def query2preset_query_collection( - collection: Collection, query_texts: List[str], n_results: int = 10 -): - collection_cnt = collection.count() - min_n_results = 10 - min_n_results = min(collection_cnt, min_n_results) - - if n_results > min_n_results: - res = collection.query(query_texts=query_texts, n_results=n_results) - return res - else: - res = collection.query(query_texts=query_texts, n_results=min_n_results) - - for _key in res.keys(): - if res[_key] is None: - continue - for _idx in range(0, len(query_texts)): - res[_key][_idx] = res[_key][_idx][:n_results] - - return res - - -def parse_retrieval_preset_query(res: List[Mapping[str, Any]]): - parsed_res = [[] for _ in range(0, len(res["ids"]))] - - retrieval_ids = res["ids"] - retrieval_distances = res["distances"] - retrieval_sentences = res["documents"] - - for query_idx in range(0, len(retrieval_ids)): - id_ls = retrieval_ids[query_idx] - distance_ls = retrieval_distances[query_idx] - sentence_ls = retrieval_sentences[query_idx] - - for idx in range(0, len(id_ls)): - id = id_ls[idx] - distance = distance_ls[idx] - sentence = sentence_ls[idx] - - parsed_res[query_idx].append( - {"id": id, "distance": distance, "presetQuery": sentence} - ) - - return parsed_res - - -def preset_query_retrieval_format( - 
query_list: List[str], retrieval_list: List[Mapping[str, Any]] -): - res = [] - for query_idx in range(0, len(query_list)): - query = query_list[query_idx] - retrieval = retrieval_list[query_idx] - - res.append({"query": query, "retrieval": retrieval}) - - return res - - -def empty_preset_query_collection(collection: Collection) -> None: - collection.delete() - - -def delete_preset_query_by_ids( - collection: Collection, preset_query_ids: List[str] -) -> None: - collection.delete(ids=preset_query_ids) - - -def get_preset_query_by_ids(collection: Collection, preset_query_ids: List[str]): - res = collection.get(ids=preset_query_ids) - - return res - - -def preset_query_collection_size(collection: Collection) -> int: - return collection.count() diff --git a/chat/core/src/main/python/services/preset_retrieval/run.py b/chat/core/src/main/python/services/preset_retrieval/run.py deleted file mode 100644 index 5c3e9adae..000000000 --- a/chat/core/src/main/python/services/preset_retrieval/run.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding:utf-8 -*- - -import os -import sys -from typing import List - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from util.logging_utils import logger -from chromadb.api import Collection - -from preset_query_db import ( - query2preset_query_collection, - parse_retrieval_preset_query, - preset_query_retrieval_format, - preset_query_collection_size, -) - -from util.text2vec import Text2VecEmbeddingFunction - -from config.config_parse import PRESET_QUERY_COLLECTION_NAME -from util.chromadb_instance import client - - -emb_func = Text2VecEmbeddingFunction() - -collection = client.get_or_create_collection( - name=PRESET_QUERY_COLLECTION_NAME, - embedding_function=emb_func, - metadata={"hnsw:space": "cosine"}, -) # Get a collection object from an existing collection, by name. If it doesn't exist, create it. 
- -logger.info("init_preset_query_collection_size: {}", preset_query_collection_size(collection)) - - -def preset_query_retrieval_run( - collection: Collection, query_texts_list: List[str], n_results: int = 5 -): - retrieval_res = query2preset_query_collection( - collection=collection, query_texts=query_texts_list, n_results=n_results - ) - - parsed_retrieval_res = parse_retrieval_preset_query(retrieval_res) - parsed_retrieval_res_format = preset_query_retrieval_format( - query_texts_list, parsed_retrieval_res - ) - - logger.info("parsed_retrieval_res_format: {}", parsed_retrieval_res_format) - - return parsed_retrieval_res_format diff --git a/chat/core/src/main/python/services/query_retrieval/retriever.py b/chat/core/src/main/python/services/query_retrieval/retriever.py new file mode 100644 index 000000000..b51b54e7f --- /dev/null +++ b/chat/core/src/main/python/services/query_retrieval/retriever.py @@ -0,0 +1,98 @@ +# -*- coding:utf-8 -*- + +import os +import sys +import uuid +from typing import Any, List, Mapping, Optional, Union + +import chromadb +from chromadb import Client +from chromadb.config import Settings +from chromadb.api import Collection, Documents, Embeddings +from chromadb.api.types import CollectionMetadata + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from instances.logging_instance import logger +from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection, + parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format, + get_chroma_collection_by_ids, get_chroma_collection_size, + add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids, + empty_chroma_collection_2) + +from instances.text2vec import Text2VecEmbeddingFunction + +class ChromaCollectionRetriever(object): + def 
__init__(self, collection:Collection): + self.collection = collection + + def retrieval_query_run(self, query_texts_list:List[str]=None, query_embeddings:Embeddings=None, + filter_condition:Mapping[str,str]=None, n_results:int=5): + + retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings, + filter_condition, n_results) + + parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res) + logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res) + parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, query_embeddings, parsed_retrieval_res) + logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format) + + return parsed_retrieval_res_format + + def get_query_by_ids(self, query_ids:List[str]): + queries = get_chroma_collection_by_ids(self.collection, query_ids) + return queries + + def get_query_size(self): + return get_chroma_collection_size(self.collection) + + def add_queries(self, query_text_list:List[str], + query_id_list:List[str], + metadatas:List[Mapping[str, str]]=None, + embeddings:Embeddings=None): + add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings) + return True + + def update_queries(self, query_text_list:List[str], + query_id_list:List[str], + metadatas:List[Mapping[str, str]]=None, + embeddings:Embeddings=None): + update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings) + return True + + def delete_queries_by_ids(self, query_ids:List[str]): + delete_chroma_collection_by_ids(self.collection, query_ids) + return True + + def empty_query_collection(self): + self.collection = empty_chroma_collection_2(self.collection) + + return True + +class CollectionManager(object): + def __init__(self, chroma_client:Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None): + self.chroma_client = chroma_client + self.embedding_func = 
embedding_func + self.collection_meta = collection_meta + + def list_collections(self): + collection_list = self.chroma_client.list_collections() + return collection_list + + def get_collection(self, collection_name:str): + collection = self.chroma_client.get_collection(name=collection_name, embedding_function=self.embedding_func) + return collection + + def create_collection(self, collection_name:str): + collection = self.chroma_client.create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta) + return collection + + def get_or_create_collection(self, collection_name:str): + collection = self.chroma_client.get_or_create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta) + return collection + + def delete_collection(self, collection_name:str): + self.chroma_client.delete_collection(collection_name) + return True \ No newline at end of file diff --git a/chat/core/src/main/python/services/query_retrieval/run.py b/chat/core/src/main/python/services/query_retrieval/run.py index 010130232..58df1b44f 100644 --- a/chat/core/src/main/python/services/query_retrieval/run.py +++ b/chat/core/src/main/python/services/query_retrieval/run.py @@ -8,80 +8,30 @@ from typing import Any, List, Mapping, Optional, Union sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from util.logging_utils import logger +from instances.logging_instance import logger import chromadb from chromadb.config import Settings from chromadb.api import Collection, Documents, Embeddings -from util.text2vec import Text2VecEmbeddingFunction +from instances.text2vec import Text2VecEmbeddingFunction +from instances.chromadb_instance import client from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME -from util.chromadb_instance import (client, - get_chroma_collection_size, 
query_chroma_collection, - parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format, - get_chroma_collection_by_ids, get_chroma_collection_size, - add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids, - empty_chroma_collection_2) +from retriever import ChromaCollectionRetriever, CollectionManager + emb_func = Text2VecEmbeddingFunction() -solved_query_collection = client.get_or_create_collection(name=SOLVED_QUERY_COLLECTION_NAME, - embedding_function=emb_func, - metadata={"hnsw:space": "cosine"} - ) # Get a collection object from an existing collection, by name. If it doesn't exist, create it. -logger.info("init_solved_query_collection_size: {}", get_chroma_collection_size(solved_query_collection)) +collection_manager = CollectionManager(chroma_client=client, embedding_func=emb_func + ,collection_meta={"hnsw:space": "cosine"}) +solved_query_collection = collection_manager.get_or_create_collection(collection_name=SOLVED_QUERY_COLLECTION_NAME) +preset_query_collection = collection_manager.get_or_create_collection(collection_name=PRESET_QUERY_COLLECTION_NAME) -preset_query_collection = client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME, - embedding_function=emb_func, - metadata={"hnsw:space": "cosine"} - ) -logger.info("init_preset_query_collection_size: {}", get_chroma_collection_size(preset_query_collection)) - -class ChromaCollectionRetriever(object): - def __init__(self, collection:Collection): - self.collection = collection - - def retrieval_query_run(self, query_texts_list:List[str], - filter_condition:Mapping[str,str]=None, n_results:int=5): - - retrieval_res = query_chroma_collection(self.collection, query_texts_list, - filter_condition, n_results) - - parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res) - parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, parsed_retrieval_res) - - logger.info('parsed_retrieval_res_format: {}', 
parsed_retrieval_res_format) - - return parsed_retrieval_res_format - - def get_query_by_ids(self, query_ids:List[str]): - queries = get_chroma_collection_by_ids(self.collection, query_ids) - return queries - - def get_query_size(self): - return get_chroma_collection_size(self.collection) - - def add_queries(self, query_text_list:List[str], - query_id_list:List[str], metadatas:List[Mapping[str, str]]=None): - add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas) - return True - - def update_queries(self, query_text_list:List[str], - query_id_list:List[str], metadatas:List[Mapping[str, str]]=None): - update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas) - return True - - def delete_queries_by_ids(self, query_ids:List[str]): - delete_chroma_collection_by_ids(self.collection, query_ids) - return True - - def empty_query_collection(self): - self.collection = empty_chroma_collection_2(self.collection) - - return True - solved_query_retriever = ChromaCollectionRetriever(solved_query_collection) preset_query_retriever = ChromaCollectionRetriever(preset_query_collection) + +logger.info("init_solved_query_collection_size: {}".format(solved_query_retriever.get_query_size())) +logger.info("init_preset_query_collection_size: {}".format(preset_query_retriever.get_query_size())) diff --git a/chat/core/src/main/python/services/sql/constructor.py b/chat/core/src/main/python/services/sql/constructor.py index e313e765f..3d4ea47e8 100644 --- a/chat/core/src/main/python/services/sql/constructor.py +++ b/chat/core/src/main/python/services/sql/constructor.py @@ -2,87 +2,64 @@ import os import sys from typing import List, Mapping +from chromadb.api import Collection sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from util.logging_utils import logger +from instances.logging_instance import logger +from services.query_retrieval.retriever 
import ChromaCollectionRetriever -from langchain.vectorstores import Chroma -from langchain.prompts.example_selector import SemanticSimilarityExampleSelector +class FewShotPromptTemplate2(object): + def __init__(self, collection:Collection, few_shot_examples:List[Mapping[str, str]], + retrieval_key:str, few_shot_seperator:str = "\n\n") -> None: + self.collection = collection + self.few_shot_retriever = ChromaCollectionRetriever(self.collection) -from few_shot_example.sql_exampler import examplars as sql_examplars -from util.text2vec import hg_embedding -from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 -from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM + self.few_shot_examples = few_shot_examples + self.retrieval_key = retrieval_key + self.few_shot_seperator = few_shot_seperator -def reload_sql_example_collection( - vectorstore: Chroma, - sql_examplars: List[Mapping[str, str]], - sql_example_selector: SemanticSimilarityExampleSelector, - example_nums: int, -): - logger.info("original sql_examples_collection size: {}", vectorstore._collection.count()) - new_collection = empty_chroma_collection_2(collection=vectorstore._collection) - vectorstore._collection = new_collection + def add_few_shot_example(self, examples: List[Mapping[str, str]])-> None: + query_text_list = [] + query_id_list = [] + for idx, example in enumerate(examples): + query_text_list.append(example[self.retrieval_key]) + query_id_list.append(str(idx)) - logger.info("emptied sql_examples_collection size: {}", vectorstore._collection.count()) + self.few_shot_retriever.add_queries(query_text_list=query_text_list, query_id_list=query_id_list, metadatas=examples) - sql_example_selector = SemanticSimilarityExampleSelector( - vectorstore=sql_examples_vectorstore, - k=example_nums, - input_keys=["question"], - example_keys=[ - "table_name", - "fields_list", - "prior_schema_links", - "question", - "analysis", - "schema_links", - 
"current_date", - "sql", - ], - ) + def reload_few_shot_example(self, examples: List[Mapping[str, str]])-> None: + logger.info(f"original sql_examples_collection size: {self.few_shot_retriever.get_query_size()}") - for example in sql_examplars: - sql_example_selector.add_example(example) + self.few_shot_retriever.empty_query_collection() + logger.info(f"emptied sql_examples_collection size: {self.few_shot_retriever.get_query_size()}") - logger.info("reloaded sql_examples_collection size: {}", vectorstore._collection.count()) + self.add_few_shot_example(examples=examples) + logger.info(f"reloaded sql_examples_collection size: {self.few_shot_retriever.get_query_size()}") - return vectorstore, sql_example_selector + def _sub_dict(self, d:Mapping[str, str], keys:List[str])-> Mapping[str, str]: + return {k:d[k] for k in keys if k in d} + def retrieve_few_shot_example(self, query_text: str, retrieval_num: int)-> List[Mapping[str, str]]: + query_text_list = [query_text] + retrieval_res_list = self.few_shot_retriever.retrieval_query_run(query_texts_list=query_text_list, + filter_condition=None, n_results=retrieval_num) + retrieval_res_unit_list = retrieval_res_list[0]['retrieval'] -sql_examples_vectorstore = Chroma( - collection_name=TEXT2DSL_COLLECTION_NAME, - embedding_function=hg_embedding, - client=chromadb_client, -) + return retrieval_res_unit_list -example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM + def make_few_shot_example_prompt(self, few_shot_template: str, example_keys: List[str], + few_shot_example_meta_list: List[Mapping[str, str]])-> str: + few_shot_example_str_unit_list = [] -sql_example_selector = SemanticSimilarityExampleSelector( - vectorstore=sql_examples_vectorstore, - k=example_nums, - input_keys=["question"], - example_keys=[ - "table_name", - "fields_list", - "prior_schema_links", - "question", - "analysis", - "schema_links", - "current_date", - "sql", - ], -) + retrieval_metas_list = [self._sub_dict(few_shot_example_meta['metadata'], example_keys) for 
few_shot_example_meta in few_shot_example_meta_list] -if sql_examples_vectorstore._collection.count() > 0: - logger.info("examples already in sql_vectorstore") - logger.info("init sql_vectorstore size: {}", sql_examples_vectorstore._collection.count()) + for meta in retrieval_metas_list: + few_shot_example_str_unit_list.append(few_shot_template.format(**meta)) + + few_shot_example_str = self.few_shot_seperator.join(few_shot_example_str_unit_list) + + return few_shot_example_str -logger.info("sql_examplars size: {}", len(sql_examplars)) -sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection( - sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums -) -logger.info("added sql_vectorstore size: {}", sql_examples_vectorstore._collection.count()) diff --git a/chat/core/src/main/python/services/sql/examples_reload_run.py b/chat/core/src/main/python/services/sql/examples_reload_run.py index cf5634c8c..1d93d3bc9 100644 --- a/chat/core/src/main/python/services/sql/examples_reload_run.py +++ b/chat/core/src/main/python/services/sql/examples_reload_run.py @@ -6,41 +6,54 @@ from typing import List, Mapping import requests +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT, - LLMPARSER_HOST, LLMPARSER_PORT) +from instances.logging_instance import logger + +from config.config_parse import (TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM, + LLMPARSER_HOST, LLMPARSER_PORT, TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY) from few_shot_example.sql_exampler import examplars as sql_examplars -from util.logging_utils import logger -def 
text2dsl_setting_update( - llm_parser_host: str, - llm_parser_port: str, - sql_examplars: List[Mapping[str, str]], - example_nums: int, - is_shortcut: bool, -): +def text2sql_agent_setting_update(llm_host:str, llm_port:str, + sql_examplars:List[Mapping[str, str]], example_nums:int): - url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/" - logger.info("url: {}", url) - payload = { - "sqlExamplars": sql_examplars, - "exampleNums": example_nums, - "isShortcut": is_shortcut, - } - headers = {"content-type": "application/json"} + url = f"http://{llm_host}:{llm_port}/text2sql_agent_setting_update/" + payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums} + headers = {'content-type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers) logger.info(response.text) +def text2dsl_agent_cs_setting_update(llm_host:str, llm_port:str, + sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int): + + url = f"http://{llm_host}:{llm_port}/texg2sqt_cs_agent_setting_update/" + payload = {"sqlExamplars":sql_examplars, + "exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums} + headers = {'content-type': 'application/json'} + response = requests.post(url, data=json.dumps(payload), headers=headers) + logger.info(response.text) + + +def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str, + is_shortcut:bool, is_self_consistency:bool, + sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int): + + url = f"http://{llm_host}:{llm_port}/query2sql_setting_update/" + payload = {"isShortcut":is_shortcut, "isSelfConsistency":is_self_consistency, + "sqlExamplars":sql_examplars, + "exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums} + headers = {'content-type': 'application/json'} + response = requests.post(url, 
data=json.dumps(payload), headers=headers) + logger.info(response.text) + if __name__ == "__main__": - text2dsl_setting_update( - LLMPARSER_HOST, - LLMPARSER_PORT, - sql_examplars, - TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, - TEXT2DSL_IS_SHORTCUT, - ) + text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT, + TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY, + sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM) + + diff --git a/chat/core/src/main/python/services/sql/output_parser.py b/chat/core/src/main/python/services/sql/output_parser.py index ee07132eb..0200d3227 100644 --- a/chat/core/src/main/python/services/sql/output_parser.py +++ b/chat/core/src/main/python/services/sql/output_parser.py @@ -8,16 +8,14 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath( sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from util.logging_utils import logger +from instances.logging_instance import logger def schema_link_parse(schema_link_output): try: schema_link_output = schema_link_output.strip() pattern = r"Schema_links:(.*)" - schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[ - 0 - ].strip() + schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip() except Exception as e: logger.exception(e) schema_link_output = None diff --git a/chat/core/src/main/python/services/sql/prompt_maker.py b/chat/core/src/main/python/services/sql/prompt_maker.py deleted file mode 100644 index 3ed8d84ad..000000000 --- a/chat/core/src/main/python/services/sql/prompt_maker.py +++ /dev/null @@ -1,166 +0,0 @@ -# -*- coding:utf-8 -*- -import os -import sys -from typing import List, Mapping - -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -from util.logging_utils import logger - -from 
langchain.prompts import PromptTemplate -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.example_selector import SemanticSimilarityExampleSelector - - -def schema_linking_exampler( - user_query: str, - domain_name: str, - fields_list: List[str], - prior_schema_links: Mapping[str, str], - example_selector: SemanticSimilarityExampleSelector, -) -> str: - - prior_schema_links_str = ( - "[" - + ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) - + "]" - ) - - example_prompt_template = PromptTemplate( - input_variables=[ - "table_name", - "fields_list", - "prior_schema_links", - "question", - "analysis", - "schema_links", - ], - template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}", - ) - - instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links" - - schema_linking_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。" - - schema_linking_example_prompt_template = FewShotPromptTemplate( - example_selector=example_selector, - example_prompt=example_prompt_template, - example_separator="\n\n", - prefix=instruction, - input_variables=["table_name", "fields_list", "prior_schema_links", "question"], - suffix=schema_linking_prompt, - ) - - schema_linking_example_prompt = schema_linking_example_prompt_template.format( - table_name=domain_name, - fields_list=fields_list, - prior_schema_links=prior_schema_links_str, - question=user_query, - ) - - return schema_linking_example_prompt - - -def sql_exampler( - user_query: str, - domain_name: str, - schema_link_str: str, - data_date: str, - example_selector: SemanticSimilarityExampleSelector, -) -> str: - - instruction = "# 根据schema_links为每个问题生成SQL查询语句" - - sql_example_prompt_template = PromptTemplate( - input_variables=[ - "question", - "current_date", - "table_name", 
- "schema_links", - "sql", - ], - template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}", - ) - - sql_prompt = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:" - - sql_example_prompt_template = FewShotPromptTemplate( - example_selector=example_selector, - example_prompt=sql_example_prompt_template, - example_separator="\n\n", - prefix=instruction, - input_variables=["question", "current_date", "table_name", "schema_links"], - suffix=sql_prompt, - ) - - sql_example_prompt = sql_example_prompt_template.format( - question=user_query, - current_date=data_date, - table_name=domain_name, - schema_links=schema_link_str, - ) - - return sql_example_prompt - - -def schema_linking_sql_combo_examplar( - user_query: str, - domain_name: str, - data_date: str, - fields_list: List[str], - prior_schema_links: Mapping[str, str], - example_selector: SemanticSimilarityExampleSelector, -) -> str: - - prior_schema_links_str = ( - "[" - + ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) - + "]" - ) - - example_prompt_template = PromptTemplate( - input_variables=[ - "table_name", - "fields_list", - "prior_schema_links", - "current_date", - "question", - "analysis", - "schema_links", - "sql", - ], - template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}", - ) - - instruction = ( - "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句" - ) - - schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。" - - schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate( - example_selector=example_selector, - 
example_prompt=example_prompt_template, - example_separator="\n\n", - prefix=instruction, - input_variables=[ - "table_name", - "fields_list", - "prior_schema_links", - "current_date", - "question", - ], - suffix=schema_linking_sql_combo_prompt, - ) - - schema_linking_sql_combo_example_prompt = ( - schema_linking_sql_combo_example_prompt_template.format( - table_name=domain_name, - fields_list=fields_list, - prior_schema_links=prior_schema_links_str, - current_date=data_date, - question=user_query, - ) - ) - return schema_linking_sql_combo_example_prompt diff --git a/chat/core/src/main/python/services/sql/run.py b/chat/core/src/main/python/services/sql/run.py index 623c8cfdd..9cb6c96d8 100644 --- a/chat/core/src/main/python/services/sql/run.py +++ b/chat/core/src/main/python/services/sql/run.py @@ -1,188 +1,56 @@ +# -*- coding:utf-8 -*- + +import asyncio + import os import sys -from typing import List, Union, Mapping +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from util.logging_utils import logger +from sql.constructor import FewShotPromptTemplate2 +from sql.sql_agent import Text2DSLAgent, Text2DSLAgentConsistency, Text2DSLAgentWrapper -from sql.prompt_maker import ( - schema_linking_exampler, - sql_exampler, - schema_linking_sql_combo_examplar, -) -from sql.constructor import ( - sql_examples_vectorstore, - sql_example_selector, - reload_sql_example_collection, -) -from sql.output_parser import ( - schema_link_parse, - combo_schema_link_parse, - combo_sql_parse, -) +from instances.llm_instance import llm +from instances.text2vec import Text2VecEmbeddingFunction +from instances.chromadb_instance import client +from instances.logging_instance import logger -from util.llm_instance import llm -from config.config_parse import TEXT2DSL_IS_SHORTCUT +from few_shot_example.sql_exampler 
import examplars as sql_examplars +from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTCS_COLLECTION_NAME, + TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM, + TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY) -class Text2DSLAgent(object): - def __init__(self): - self.schema_linking_exampler = schema_linking_exampler - self.sql_exampler = sql_exampler +emb_func = Text2VecEmbeddingFunction() +text2dsl_agent_collection = client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME, + embedding_function=emb_func, + metadata={"hnsw:space": "cosine"}) +text2dsl_agentcs_collection = client.get_or_create_collection(name=TEXT2DSLAGENTCS_COLLECTION_NAME, + embedding_function=emb_func, + metadata={"hnsw:space": "cosine"}) - self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar +text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection, + few_shot_examples=sql_examplars, + retrieval_key="question", + few_shot_seperator='\n\n') - self.sql_examples_vectorstore = sql_examples_vectorstore - self.sql_example_selector = sql_example_selector +text2dsl_agentcs_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agentcs_collection, + few_shot_examples=sql_examplars, + retrieval_key="question", + few_shot_seperator='\n\n') - self.schema_link_parse = schema_link_parse - self.combo_schema_link_parse = combo_schema_link_parse - self.combo_sql_parse = combo_sql_parse +text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_EXAMPLE_NUM, + sql_example_prompter=text2dsl_agent_example_prompter, llm=llm) - self.llm = llm +text2sql_cs_agent = Text2DSLAgentConsistency(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM, + sql_example_prompter=text2dsl_agentcs_example_prompter, llm=llm) - self.is_shortcut = TEXT2DSL_IS_SHORTCUT +text2sql_agent.update_examples(sql_examplars, TEXT2DSL_EXAMPLE_NUM) - def 
update_examples(self, sql_examples, example_nums, is_shortcut): - ( - self.sql_examples_vectorstore, - self.sql_example_selector, - ) = reload_sql_example_collection( - self.sql_examples_vectorstore, - sql_examples, - self.sql_example_selector, - example_nums, - ) - self.is_shortcut = is_shortcut - - def query2sql( - self, - query_text: str, - schema: Union[dict, None] = None, - current_date: str = None, - linking: Union[List[Mapping[str, str]], None] = None, - ): - - logger.info("query_text: {}".format(query_text)) - logger.info("schema: {}".format(schema)) - logger.info("current_date: {}".format(current_date)) - logger.info("prior_schema_links: {}".format(linking)) - - if linking is not None: - prior_schema_links = { - item["fieldValue"]: item["fieldName"] for item in linking - } - else: - prior_schema_links = {} - - model_name = schema["modelName"] - fields_list = schema["fieldNameList"] - - schema_linking_prompt = self.schema_linking_exampler( - query_text, - model_name, - fields_list, - prior_schema_links, - self.sql_example_selector, - ) - logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt)) - schema_link_output = self.llm(schema_linking_prompt) - schema_link_str = self.schema_link_parse(schema_link_output) - - sql_prompt = self.sql_exampler( - query_text, - model_name, - schema_link_str, - current_date, - self.sql_example_selector, - ) - logger.info("sql_prompt-> {}".format(sql_prompt)) - sql_output = self.llm(sql_prompt) - - resp = dict() - resp["query"] = query_text - resp["model"] = model_name - resp["fields"] = fields_list - resp["priorSchemaLinking"] = linking - resp["dataDate"] = current_date - - resp["analysisOutput"] = schema_link_output - resp["schemaLinkStr"] = schema_link_str - - resp["sqlOutput"] = sql_output - - logger.info("resp: {}".format(resp)) - - return resp - - def query2sqlcombo( - self, - query_text: str, - schema: Union[dict, None] = None, - current_date: str = None, - linking: Union[List[Mapping[str, str]], None] = 
None, - ): - - logger.info("query_text: {}".format(query_text)) - logger.info("schema: {}".format(schema)) - logger.info("current_date: {}".format(current_date)) - logger.info("prior_schema_links: {}".format(linking)) - - if linking is not None: - prior_schema_links = { - item["fieldValue"]: item["fieldName"] for item in linking - } - else: - prior_schema_links = {} - - model_name = schema["modelName"] - fields_list = schema["fieldNameList"] - - schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler( - query_text, - model_name, - current_date, - fields_list, - prior_schema_links, - self.sql_example_selector, - ) - logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt)) - schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) - - schema_linking_str = self.combo_schema_link_parse( - schema_linking_sql_combo_output - ) - sql_str = self.combo_sql_parse(schema_linking_sql_combo_output) - - resp = dict() - resp["query"] = query_text - resp["model"] = model_name - resp["fields"] = fields_list - resp["priorSchemaLinking"] = prior_schema_links - resp["dataDate"] = current_date - - resp["analysisOutput"] = schema_linking_sql_combo_output - resp["schemaLinkStr"] = schema_linking_str - resp["sqlOutput"] = sql_str - - logger.info("resp: {}".format(resp)) - - return resp - - def query2sql_run( - self, - query_text: str, - schema: Union[dict, None] = None, - current_date: str = None, - linking: Union[List[Mapping[str, str]], None] = None, - ): - - if self.is_shortcut: - return self.query2sqlcombo(query_text, schema, current_date, linking) - else: - return self.query2sql(query_text, schema, current_date, linking) +text2sql_cs_agent.update_examples(sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM) -text2sql_agent = Text2DSLAgent() +text2sql_agent_router = Text2DSLAgentWrapper(sql_agent=text2sql_agent, sql_agent_cs=text2sql_cs_agent, + is_shortcut=TEXT2DSL_IS_SHORTCUT, 
is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY) \ No newline at end of file diff --git a/chat/core/src/main/python/services/sql/sql_agent.py b/chat/core/src/main/python/services/sql/sql_agent.py new file mode 100644 index 000000000..0325cf25b --- /dev/null +++ b/chat/core/src/main/python/services/sql/sql_agent.py @@ -0,0 +1,380 @@ +import os +import sys +from typing import List, Union, Mapping, Any +from collections import Counter +import random +import asyncio +from langchain.llms.base import BaseLLM + +sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from instances.logging_instance import logger + + +from sql.constructor import FewShotPromptTemplate2 +from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse + + +class Text2DSLAgent(object): + def __init__(self, num_fewshots:int, + sql_example_prompter:FewShotPromptTemplate2, + llm: BaseLLM): + self.num_fewshots = num_fewshots + self.sql_example_prompter = sql_example_prompter + self.llm = llm + + def update_examples(self, sql_examplars, num_fewshots): + self.num_fewshots = num_fewshots + self.sql_example_prompter.reload_few_shot_example(sql_examplars) + + def get_fewshot_examples(self, query_text: str)->List[Mapping[str, str]]: + few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_fewshots) + + return few_shot_example_meta_list + + def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str], + prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str: + + prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' + + instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links" + + schema_linking_example_keys = 
["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"] + schema_linking_example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}" + schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template, + example_keys=schema_linking_example_keys, + few_shot_example_meta_list=fewshot_example_list) + + new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。" + new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, question=user_query) + + schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt + return schema_linking_prompt + + def generate_sql_prompt(self, user_query: str, domain_name: str, + schema_link_str: str, data_date: str, + fewshot_example_list:List[Mapping[str, str]])-> str: + instruction = "# 根据schema_links为每个问题生成SQL查询语句" + sql_example_keys = ["question", "current_date", "table_name", "schema_links", "sql"] + sql_example_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}" + + + sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template, + example_keys=sql_example_keys, + few_shot_example_meta_list=fewshot_example_list) + + new_case_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:" + new_case_prompt = new_case_template.format(question=user_query, current_date=data_date, table_name=domain_name, schema_links=schema_link_str) + + sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt + + return 
sql_example_prompt + + def generate_schema_linking_sql_prompt(self, user_query: str, + domain_name: str, + data_date : str, + fields_list: List[str], + prior_schema_links: Mapping[str,str], + fewshot_example_list:List[Mapping[str, str]]): + + prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' + + instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句" + + example_keys = ["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"] + example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}" + fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template, + example_keys=example_keys, + few_shot_example_meta_list=fewshot_example_list) + + new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。" + new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, current_date=data_date, question=user_query) + + prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt + + return prompt + + async def async_query2sql(self, query_text: str, + model_name: str, fields_list: List[str], + data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str): + logger.info("query_text: {}".format(query_text)) + logger.info("model_name: {}".format(model_name)) + logger.info("fields_list: {}".format(fields_list)) + logger.info("data_date: {}".format(data_date)) + logger.info("prior_schema_links: {}".format(prior_schema_links)) + logger.info("prior_exts: {}".format(prior_exts)) + + query_text = query_text + ' 备注:'+prior_exts + 
logger.info("query_text_prior_exts: {}".format(query_text)) + + fewshot_example_meta_list = self.get_fewshot_examples(query_text) + schema_linking_prompt = self.generate_schema_linking_prompt(query_text, model_name, fields_list, prior_schema_links, fewshot_example_meta_list) + logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt)) + schema_link_output = await self.llm._call_async(schema_linking_prompt) + + schema_link_str = schema_link_parse(schema_link_output) + + sql_prompt = self.generate_sql_prompt(query_text, model_name, schema_link_str, data_date, fewshot_example_meta_list) + logger.debug("sql_prompt->{}".format(sql_prompt)) + sql_output = await self.llm._call_async(sql_prompt) + + resp = dict() + resp['query'] = query_text + resp['model'] = model_name + resp['fields'] = fields_list + resp['priorSchemaLinking'] = prior_schema_links + resp['dataDate'] = data_date + + resp['schemaLinkingOutput'] = schema_link_output + resp['schemaLinkStr'] = schema_link_str + + resp['sqlOutput'] = sql_output + + logger.info("resp: {}".format(resp)) + + return resp + + async def async_query2sql_shortcut(self, query_text: str, + model_name: str, fields_list: List[str], + data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str): + logger.info("query_text: {}".format(query_text)) + logger.info("model_name: {}".format(model_name)) + logger.info("fields_list: {}".format(fields_list)) + logger.info("data_date: {}".format(data_date)) + logger.info("prior_schema_links: {}".format(prior_schema_links)) + logger.info("prior_exts: {}".format(prior_exts)) + + query_text = query_text + ' 备注:'+prior_exts + logger.info("query_text_prior_exts: {}".format(query_text)) + + fewshot_example_meta_list = self.get_fewshot_examples(query_text) + schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(query_text, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list) + 
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt)) + schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt) + + schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output) + sql_str = combo_sql_parse(schema_linking_sql_shortcut_output) + + resp = dict() + resp['query'] = query_text + resp['model'] = model_name + resp['fields'] = fields_list + resp['priorSchemaLinking'] = prior_schema_links + resp['dataDate'] = data_date + + resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output + resp['schemaLinkStr'] = schema_linking_str + resp['sqlOutput'] = sql_str + + logger.info("resp: {}".format(resp)) + + return resp + +class Text2DSLAgentConsistency(object): + def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int, + sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM) -> None: + self.num_fewshots = num_fewshots + self.num_examples = num_examples + assert self.num_fewshots <= self.num_examples + self.num_self_consistency = num_self_consistency + + self.llm = llm + self.sql_example_prompter = sql_example_prompter + + def update_examples(self, sql_examplars, num_examples, num_fewshots, num_self_consistency): + self.num_fewshots = num_fewshots + self.num_examples = num_examples + assert self.num_fewshots <= self.num_examples + self.num_self_consistency = num_self_consistency + assert self.num_self_consistency >= 1 + self.sql_example_prompter.reload_few_shot_example(sql_examplars) + + def get_examples_candidates(self, query_text: str)->List[Mapping[str, str]]: + few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_examples) + + return few_shot_example_meta_list + + def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]])-> List[List[Mapping[str, str]]]: + fewshot_example_list = [] + for i in range(0, self.num_self_consistency): + 
random.shuffle(example_meta_list) + fewshot_example_list.append(example_meta_list[:self.num_fewshots]) + + return fewshot_example_list + + def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str], + prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str: + + prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' + + instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links" + + schema_linking_example_keys = ["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"] + schema_linking_example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}" + schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template, + example_keys=schema_linking_example_keys, + few_shot_example_meta_list=fewshot_example_list) + + new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。" + new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, question=user_query) + + schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt + return schema_linking_prompt + + def generate_schema_linking_prompt_pool(self, user_query: str, domain_name: str, fields_list: List[str], + prior_schema_links: Mapping[str,str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]: + schema_linking_prompt_pool = [] + for fewshot_example_list in fewshot_example_list_pool: + schema_linking_prompt = self.generate_schema_linking_prompt(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list) + 
schema_linking_prompt_pool.append(schema_linking_prompt) + + return schema_linking_prompt_pool + + def generate_sql_prompt(self, user_query: str, domain_name: str, + schema_link_str: str, data_date: str, + fewshot_example_list:List[Mapping[str, str]])-> str: + instruction = "# 根据schema_links为每个问题生成SQL查询语句" + sql_example_keys = ["question", "current_date", "table_name", "schema_links", "sql"] + sql_example_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}" + + + sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template, + example_keys=sql_example_keys, + few_shot_example_meta_list=fewshot_example_list) + + new_case_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:" + new_case_prompt = new_case_template.format(question=user_query, current_date=data_date, table_name=domain_name, schema_links=schema_link_str) + + sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt + + return sql_example_prompt + + def generate_sql_prompt_pool(self, user_query: str, domain_name: str, data_date: str, + schema_link_str_pool: List[str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]: + sql_prompt_pool = [] + for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool): + sql_prompt = self.generate_sql_prompt(user_query, domain_name, schema_link_str, data_date, fewshot_example_list) + sql_prompt_pool.append(sql_prompt) + + return sql_prompt_pool + + def self_consistency_vote(self, output_res_pool:List[str]): + output_res_counts = Counter(output_res_pool) + output_res_max = output_res_counts.most_common(1)[0][0] + total_output_num = len(output_res_pool) + + vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()} + + return output_res_max, vote_percentage + + def 
schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]: + schema_linking_list_unify = [] + for schema_linking_str in schema_linking_list: + schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')])) + schema_linking_str_unify = f'[{schema_linking_str_unify}]' + schema_linking_list_unify.append(schema_linking_str_unify) + + return schema_linking_list_unify + + + async def generate_schema_linking_tasks(self, user_query: str, domain_name: str, + fields_list: List[str], prior_schema_links: Mapping[str,str], + fewshot_example_list_combo:List[List[Mapping[str, str]]]): + + schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(user_query, domain_name, + fields_list, prior_schema_links, + fewshot_example_list_combo) + schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool] + schema_linking_output_res_pool = await asyncio.gather(*schema_linking_output_task_pool) + logger.debug(f'schema_linking_output_res_pool:{schema_linking_output_res_pool}') + + return schema_linking_output_res_pool + + async def generate_sql_tasks(self, user_query: str, domain_name: str, data_date: str, + schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]]): + + sql_prompt_pool = self.generate_sql_prompt_pool(user_query, domain_name, schema_link_str_pool, data_date, fewshot_example_list_combo) + sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool] + sql_output_res_pool = await asyncio.gather(*sql_output_task_pool) + logger.debug(f'sql_output_res_pool:{sql_output_res_pool}') + + return sql_output_res_pool + + async def tasks_run(self, user_query: str, domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str): + logger.info("user_query: {}".format(user_query)) + logger.info("domain_name: 
{}".format(domain_name)) + logger.info("fields_list: {}".format(fields_list)) + logger.info("current_date: {}".format(data_date)) + logger.info("prior_schema_links: {}".format(prior_schema_links)) + logger.info("prior_exts: {}".format(prior_exts)) + + user_query = user_query + ' 备注:'+prior_exts + logger.info("user_query_prior_exts: {}".format(user_query)) + + fewshot_example_meta_list = self.get_examples_candidates(user_query) + fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list) + + schema_linking_output_candidates = await self.generate_schema_linking_tasks(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo) + schema_linking_candidate_list = [schema_link_parse(schema_linking_output_candidate) for schema_linking_output_candidate in schema_linking_output_candidates] + logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}') + schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list) + logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}') + + schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list) + + sql_output_candicates = await self.generate_sql_tasks(user_query, domain_name, data_date, schema_linking_candidate_list,fewshot_example_list_combo) + logger.debug(f'sql_output_candicates:{sql_output_candicates}') + sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates) + + resp = dict() + resp['query'] = user_query + resp['model'] = domain_name + resp['fields'] = fields_list + resp['priorSchemaLinking'] = prior_schema_links + resp['dataDate'] = data_date + + resp['schemaLinkStr'] = schema_linking_output_max + resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage + + resp['sqlOutput'] = sql_output_max + resp['sqlWeight'] = sql_output_vote_percentage + + 
logger.info("resp: {}".format(resp)) + + return resp + +class Text2DSLAgentWrapper(object): + def __init__(self, sql_agent:Text2DSLAgent, sql_agent_cs:Text2DSLAgentConsistency, + is_shortcut:bool, is_self_consistency:bool): + self.sql_agent = sql_agent + self.sql_agent_cs = sql_agent_cs + + self.is_shortcut = is_shortcut + self.is_self_consistency = is_self_consistency + + async def async_query2sql(self, query_text: str, + model_name: str, fields_list: List[str], + data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str): + if self.is_self_consistency: + logger.info("sql wrapper: self_consistency") + resp = await self.sql_agent_cs.tasks_run(user_query=query_text, domain_name=model_name, fields_list=fields_list, prior_schema_links=prior_schema_links, data_date=data_date, prior_exts=prior_exts) + return resp + elif self.is_shortcut: + logger.info("sql wrapper: shortcut") + resp = await self.sql_agent.async_query2sql_shortcut(query_text=query_text, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts) + return resp + else: + logger.info("sql wrapper: normal") + resp = await self.sql_agent.async_query2sql(query_text=query_text, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts) + return resp + + def update_configs(self, is_shortcut, is_self_consistency, + sql_examplars, num_examples, num_fewshots, num_self_consistency): + self.is_shortcut = is_shortcut + self.is_self_consistency = is_self_consistency + + self.sql_agent.update_examples(sql_examplars=sql_examplars, num_fewshots=num_examples) + self.sql_agent_cs.update_examples(sql_examplars=sql_examplars, num_examples=num_examples, num_fewshots=num_fewshots, num_self_consistency=num_self_consistency) + diff --git a/chat/core/src/main/python/services_router/query2sql_service.py b/chat/core/src/main/python/services_router/query2sql_service.py index 
def _require_field(container: Any, key: str, owner: str = "query_body") -> Any:
    """Return ``container[key]``, or raise HTTP 400 naming the missing key."""
    if not isinstance(container, Mapping) or key not in container:
        raise HTTPException(status_code=400,
                            detail="{} is not in {}".format(key, owner))
    return container[key]


@router.post("/query2sql")
async def query2sql(query_body: Mapping[str, Any]):
    """Translate a natural-language query into SQL via the agent router.

    Required keys in ``query_body``: ``queryText``, ``schema`` (with
    ``modelName`` and ``fieldNameList``), ``currentDate``, ``linking`` (an
    iterable of items carrying ``fieldValue`` and ``fieldName``) and
    ``priorExts``.

    Raises:
        HTTPException: status 400 when a required key is missing or when
            ``linking`` is malformed.

    Returns:
        The response produced by ``text2sql_agent_router.async_query2sql``.
    """
    # Validation order matches the original handler so the first reported
    # missing key is unchanged.
    query_text = _require_field(query_body, 'queryText')
    schema = _require_field(query_body, 'schema')
    current_date = _require_field(query_body, 'currentDate')
    linking = _require_field(query_body, 'linking')
    prior_exts = _require_field(query_body, 'priorExts')

    # Nested keys come from the client as well; report them as a bad request
    # rather than letting a KeyError surface as HTTP 500.
    model_name = _require_field(schema, 'modelName', owner='schema')
    fields_list = _require_field(schema, 'fieldNameList', owner='schema')

    try:
        prior_schema_links = {item['fieldValue']: item['fieldName'] for item in linking}
    except (KeyError, TypeError) as err:
        # KeyError: an item lacks a key. TypeError: linking is not iterable,
        # an item is not subscriptable, or fieldValue is unhashable.
        raise HTTPException(
            status_code=400,
            detail="linking must be a list of items with fieldValue and fieldName",
        ) from err

    resp = await text2sql_agent_router.async_query2sql(query_text=query_text, model_name=model_name, fields_list=fields_list,
                                                       data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)

    return resp


@router.post("/query2sql_setting_update")
def query2sql_setting_update(query_body: Mapping[str, Any]):
    """Update the few-shot examples and mode flags of the agent router.

    Required keys in ``query_body``: ``sqlExamplars``, ``exampleNums``,
    ``fewshotNums``, ``selfConsistencyNums``, ``isShortcut`` and
    ``isSelfConsistency``.

    Raises:
        HTTPException: status 400 naming the first missing key.

    Returns:
        The literal string ``"success"``.
    """
    sql_examplars = _require_field(query_body, 'sqlExamplars')
    example_nums = _require_field(query_body, 'exampleNums')
    fewshot_nums = _require_field(query_body, 'fewshotNums')
    self_consistency_nums = _require_field(query_body, 'selfConsistencyNums')
    is_shortcut = _require_field(query_body, 'isShortcut')
    is_self_consistency = _require_field(query_body, 'isSelfConsistency')

    text2sql_agent_router.update_configs(is_shortcut=is_shortcut, is_self_consistency=is_self_consistency, sql_examplars=sql_examplars,
                                         num_examples=example_nums, num_fewshots=fewshot_nums, num_self_consistency=self_consistency_nums)

    return "success"
def chroma_collection_query_retrieval_format(query_list:List[str], query_embeddings:Embeddings ,retrieval_list:List[Mapping[str, Any]]):
    """Pair each query (text or embedding) with its retrieval result.

    Exactly one of ``query_list`` and ``query_embeddings`` must be given;
    the other must be ``None``.

    Args:
        query_list: Query texts, or ``None`` when embeddings are used.
        query_embeddings: Query embeddings, or ``None`` when texts are used.
        retrieval_list: Retrieval results, indexed in step with the queries.

    Returns:
        A list with one dict per query: ``{'query': ..., 'retrieval': ...}``
        for text queries, ``{'query_embedding': ..., 'retrieval': ...}`` for
        embedding queries.

    Raises:
        ValueError: if both inputs are given or both are ``None``.
        IndexError: if ``retrieval_list`` is shorter than the query input.
    """
    # ValueError is a subclass of Exception, so callers that caught the
    # previous bare Exception keep working.
    if query_list is not None and query_embeddings is not None:
        raise ValueError("query_list and query_embeddings are not None")
    if query_list is None and query_embeddings is None:
        raise ValueError("query_list and query_embeddings are None")

    if query_list is not None:
        result_key, queries = 'query', query_list
    else:
        result_key, queries = 'query_embedding', query_embeddings

    # Index retrieval_list explicitly (rather than zip) so a too-short
    # retrieval_list still raises IndexError instead of silently truncating.
    return [
        {result_key: query, 'retrieval': retrieval_list[query_idx]}
        for query_idx, query in enumerate(queries)
    ]