mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
add self-consistency feature for text2sql (#303)
This commit is contained in:
@@ -1,7 +1,14 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import configparser
|
||||
from util.logging_utils import logger
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
def type_convert(input_str: str):
|
||||
@@ -11,10 +18,12 @@ def type_convert(input_str: str):
|
||||
return input_str
|
||||
|
||||
|
||||
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
PROJECT_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
config_dir = "config"
|
||||
CONFIG_DIR_PATH = os.path.join(PROJECT_DIR_PATH, config_dir)
|
||||
config_file = "run_config.ini"
|
||||
config_path = os.path.join(PROJECT_DIR_PATH, config_file)
|
||||
config_path = os.path.join(CONFIG_DIR_PATH, config_file)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read(config_path)
|
||||
|
||||
@@ -26,9 +35,13 @@ chroma_db_section_name = "ChromaDB"
|
||||
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
|
||||
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
|
||||
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
|
||||
TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
|
||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM'))
|
||||
TEXT2DSLAGENT_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENT_COLLECTION_NAME')
|
||||
TEXT2DSLAGENTCS_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENTCS_COLLECTION_NAME')
|
||||
TEXT2DSL_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_EXAMPLE_NUM'))
|
||||
TEXT2DSL_FEWSHOTS_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEWSHOTS_NUM'))
|
||||
TEXT2DSL_SELF_CONSISTENCY_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_SELF_CONSISTENCY_NUM'))
|
||||
TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT'))
|
||||
TEXT2DSL_IS_SELF_CONSISTENCY = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SELF_CONSISTENCY'))
|
||||
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
||||
|
||||
text2vec_section_name = "Text2Vec"
|
||||
@@ -44,10 +57,14 @@ for option in config.options(llm_model_section_name):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH)
|
||||
logger.info("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME)
|
||||
logger.info("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH)
|
||||
logger.info("LLMPARSER_HOST: ", LLMPARSER_HOST)
|
||||
logger.info("LLMPARSER_PORT: ", LLMPARSER_PORT)
|
||||
logger.info("llm_config_dict: ", llm_config_dict)
|
||||
logger.info("is_shortcut: ", TEXT2DSL_IS_SHORTCUT)
|
||||
logger.info(f"PROJECT_DIR_PATH: {PROJECT_DIR_PATH}")
|
||||
logger.info(f"EMB_MODEL_PATH: {HF_TEXT2VEC_MODEL_NAME}")
|
||||
logger.info(f"CHROMA_DB_PERSIST_PATH: {CHROMA_DB_PERSIST_PATH}")
|
||||
logger.info(f"LLMPARSER_HOST: {LLMPARSER_HOST}")
|
||||
logger.info(f"LLMPARSER_PORT: {LLMPARSER_PORT}")
|
||||
logger.info(f"llm_config_dict: {llm_config_dict}")
|
||||
logger.info(f"TEXT2DSL_EXAMPLE_NUM: {TEXT2DSL_EXAMPLE_NUM}")
|
||||
logger.info(f"TEXT2DSL_FEWSHOTS_NUM: {TEXT2DSL_FEWSHOTS_NUM}")
|
||||
logger.info(f"TEXT2DSL_SELF_CONSISTENCY_NUM: {TEXT2DSL_SELF_CONSISTENCY_NUM}")
|
||||
logger.info(f"TEXT2DSL_IS_SHORTCUT: {TEXT2DSL_IS_SHORTCUT}")
|
||||
logger.info(f"TEXT2DSL_IS_SELF_CONSISTENCY: {TEXT2DSL_IS_SELF_CONSISTENCY}")
|
||||
|
||||
@@ -6,9 +6,13 @@ LLMPARSER_PORT = 9092
|
||||
CHROMA_DB_PERSIST_DIR = chm_db
|
||||
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
|
||||
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
|
||||
TEXT2DSL_COLLECTION_NAME = text2dsl_collection
|
||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
||||
TEXT2DSLAGENT_COLLECTION_NAME = text2dsl_agent_collection
|
||||
TEXT2DSLAGENTCS_COLLECTION_NAME = text2dsl_agent_cs_collection
|
||||
TEXT2DSL_EXAMPLE_NUM = 15
|
||||
TEXT2DSL_FEWSHOTS_NUM = 10
|
||||
TEXT2DSL_SELF_CONSISTENCY_NUM = 5
|
||||
TEXT2DSL_IS_SHORTCUT = False
|
||||
TEXT2DSL_IS_SELF_CONSISTENCY = False
|
||||
|
||||
[Text2Vec]
|
||||
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese
|
||||
|
||||
@@ -228,10 +228,11 @@ examplars= [
|
||||
"question":"邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?",
|
||||
"prior_schema_links":"""['2312311'->MPPM歌手ID]""",
|
||||
"analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?“中,我们被问:
|
||||
“歌曲中,有哪些”,所以我们需要column=[歌曲名]
|
||||
“播放量大于500W的”,所以我们需要column=[结算播放量], cell values = [5000000],所以有[结算播放量:(5000000)]
|
||||
”邓梓琦在2023年1月5日之后发布的“,所以我们需要column=[发布时间], cell values = ['2023-01-05'],所以有[发布时间:('2023-01-05')]
|
||||
”邓梓琦“,所以我们需要column=[歌手名], cell values = ['邓梓琦'],所以有[歌手名:('邓梓琦')]""",
|
||||
"schema_links":"""["结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""",
|
||||
"schema_links":"""["歌曲名", "结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""",
|
||||
"sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '2023-01-05' and 歌手名 = '邓梓琦' and 结算播放量 > 5000000"""
|
||||
},
|
||||
{ "current_date":"2023-09-17",
|
||||
|
||||
21
chat/core/src/main/python/instances/chromadb_instance.py
Normal file
21
chat/core/src/main/python/instances/chromadb_instance.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb.api import Collection
|
||||
from chromadb.config import Settings
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from config.config_parse import CHROMA_DB_PERSIST_PATH
|
||||
|
||||
|
||||
client = chromadb.Client(
|
||||
Settings(
|
||||
chroma_db_impl="duckdb+parquet",
|
||||
persist_directory=CHROMA_DB_PERSIST_PATH, # Optional, defaults to .chromadb/ in the current directory
|
||||
)
|
||||
)
|
||||
6
chat/core/src/main/python/instances/logging_instance.py
Normal file
6
chat/core/src/main/python/instances/logging_instance.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
logger.remove() #remove the old handler. Else, the old one will work along with the new one you've added below'
|
||||
logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", level="INFO")
|
||||
|
||||
@@ -5,7 +5,7 @@ import re
|
||||
import sys
|
||||
from typing import Any, List, Mapping, Union
|
||||
|
||||
from util.logging_utils import logger
|
||||
from instances.logging_instance import logger
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
@@ -12,7 +12,7 @@ from plugin_call.prompt_construct import (
|
||||
construct_task_prompt,
|
||||
plugin_selection_output_parse,
|
||||
)
|
||||
from util.llm_instance import llm
|
||||
from instances.llm_instance import llm
|
||||
|
||||
|
||||
def plugin_selection_run(
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
from chromadb.api import Collection
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
def get_ids(documents: List[str]) -> List[str]:
|
||||
ids = []
|
||||
for doc in documents:
|
||||
ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, doc)))
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def add2preset_query_collection(
|
||||
collection: Collection, preset_queries: List[str], preset_query_ids: List[str]
|
||||
) -> None:
|
||||
|
||||
collection.add(documents=preset_queries, ids=preset_query_ids)
|
||||
|
||||
|
||||
def update_preset_query_collection(
|
||||
collection: Collection, preset_queries: List[str], preset_query_ids: List[str]
|
||||
) -> None:
|
||||
|
||||
collection.update(documents=preset_queries, ids=preset_query_ids)
|
||||
|
||||
|
||||
def query2preset_query_collection(
|
||||
collection: Collection, query_texts: List[str], n_results: int = 10
|
||||
):
|
||||
collection_cnt = collection.count()
|
||||
min_n_results = 10
|
||||
min_n_results = min(collection_cnt, min_n_results)
|
||||
|
||||
if n_results > min_n_results:
|
||||
res = collection.query(query_texts=query_texts, n_results=n_results)
|
||||
return res
|
||||
else:
|
||||
res = collection.query(query_texts=query_texts, n_results=min_n_results)
|
||||
|
||||
for _key in res.keys():
|
||||
if res[_key] is None:
|
||||
continue
|
||||
for _idx in range(0, len(query_texts)):
|
||||
res[_key][_idx] = res[_key][_idx][:n_results]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def parse_retrieval_preset_query(res: List[Mapping[str, Any]]):
|
||||
parsed_res = [[] for _ in range(0, len(res["ids"]))]
|
||||
|
||||
retrieval_ids = res["ids"]
|
||||
retrieval_distances = res["distances"]
|
||||
retrieval_sentences = res["documents"]
|
||||
|
||||
for query_idx in range(0, len(retrieval_ids)):
|
||||
id_ls = retrieval_ids[query_idx]
|
||||
distance_ls = retrieval_distances[query_idx]
|
||||
sentence_ls = retrieval_sentences[query_idx]
|
||||
|
||||
for idx in range(0, len(id_ls)):
|
||||
id = id_ls[idx]
|
||||
distance = distance_ls[idx]
|
||||
sentence = sentence_ls[idx]
|
||||
|
||||
parsed_res[query_idx].append(
|
||||
{"id": id, "distance": distance, "presetQuery": sentence}
|
||||
)
|
||||
|
||||
return parsed_res
|
||||
|
||||
|
||||
def preset_query_retrieval_format(
|
||||
query_list: List[str], retrieval_list: List[Mapping[str, Any]]
|
||||
):
|
||||
res = []
|
||||
for query_idx in range(0, len(query_list)):
|
||||
query = query_list[query_idx]
|
||||
retrieval = retrieval_list[query_idx]
|
||||
|
||||
res.append({"query": query, "retrieval": retrieval})
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def empty_preset_query_collection(collection: Collection) -> None:
|
||||
collection.delete()
|
||||
|
||||
|
||||
def delete_preset_query_by_ids(
|
||||
collection: Collection, preset_query_ids: List[str]
|
||||
) -> None:
|
||||
collection.delete(ids=preset_query_ids)
|
||||
|
||||
|
||||
def get_preset_query_by_ids(collection: Collection, preset_query_ids: List[str]):
|
||||
res = collection.get(ids=preset_query_ids)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def preset_query_collection_size(collection: Collection) -> int:
|
||||
return collection.count()
|
||||
@@ -1,51 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
from chromadb.api import Collection
|
||||
|
||||
from preset_query_db import (
|
||||
query2preset_query_collection,
|
||||
parse_retrieval_preset_query,
|
||||
preset_query_retrieval_format,
|
||||
preset_query_collection_size,
|
||||
)
|
||||
|
||||
from util.text2vec import Text2VecEmbeddingFunction
|
||||
|
||||
from config.config_parse import PRESET_QUERY_COLLECTION_NAME
|
||||
from util.chromadb_instance import client
|
||||
|
||||
|
||||
emb_func = Text2VecEmbeddingFunction()
|
||||
|
||||
collection = client.get_or_create_collection(
|
||||
name=PRESET_QUERY_COLLECTION_NAME,
|
||||
embedding_function=emb_func,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
|
||||
|
||||
logger.info("init_preset_query_collection_size: {}", preset_query_collection_size(collection))
|
||||
|
||||
|
||||
def preset_query_retrieval_run(
|
||||
collection: Collection, query_texts_list: List[str], n_results: int = 5
|
||||
):
|
||||
retrieval_res = query2preset_query_collection(
|
||||
collection=collection, query_texts=query_texts_list, n_results=n_results
|
||||
)
|
||||
|
||||
parsed_retrieval_res = parse_retrieval_preset_query(retrieval_res)
|
||||
parsed_retrieval_res_format = preset_query_retrieval_format(
|
||||
query_texts_list, parsed_retrieval_res
|
||||
)
|
||||
|
||||
logger.info("parsed_retrieval_res_format: {}", parsed_retrieval_res_format)
|
||||
|
||||
return parsed_retrieval_res_format
|
||||
@@ -0,0 +1,98 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Any, List, Mapping, Optional, Union
|
||||
|
||||
import chromadb
|
||||
from chromadb import Client
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
from chromadb.api.types import CollectionMetadata
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection,
|
||||
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
||||
get_chroma_collection_by_ids, get_chroma_collection_size,
|
||||
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
|
||||
empty_chroma_collection_2)
|
||||
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
|
||||
class ChromaCollectionRetriever(object):
|
||||
def __init__(self, collection:Collection):
|
||||
self.collection = collection
|
||||
|
||||
def retrieval_query_run(self, query_texts_list:List[str]=None, query_embeddings:Embeddings=None,
|
||||
filter_condition:Mapping[str,str]=None, n_results:int=5):
|
||||
|
||||
retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings,
|
||||
filter_condition, n_results)
|
||||
|
||||
parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
|
||||
logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res)
|
||||
parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, query_embeddings, parsed_retrieval_res)
|
||||
logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)
|
||||
|
||||
return parsed_retrieval_res_format
|
||||
|
||||
def get_query_by_ids(self, query_ids:List[str]):
|
||||
queries = get_chroma_collection_by_ids(self.collection, query_ids)
|
||||
return queries
|
||||
|
||||
def get_query_size(self):
|
||||
return get_chroma_collection_size(self.collection)
|
||||
|
||||
def add_queries(self, query_text_list:List[str],
|
||||
query_id_list:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None):
|
||||
add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
|
||||
return True
|
||||
|
||||
def update_queries(self, query_text_list:List[str],
|
||||
query_id_list:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None):
|
||||
update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
|
||||
return True
|
||||
|
||||
def delete_queries_by_ids(self, query_ids:List[str]):
|
||||
delete_chroma_collection_by_ids(self.collection, query_ids)
|
||||
return True
|
||||
|
||||
def empty_query_collection(self):
|
||||
self.collection = empty_chroma_collection_2(self.collection)
|
||||
|
||||
return True
|
||||
|
||||
class CollectionManager(object):
|
||||
def __init__(self, chroma_client:Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None):
|
||||
self.chroma_client = chroma_client
|
||||
self.embedding_func = embedding_func
|
||||
self.collection_meta = collection_meta
|
||||
|
||||
def list_collections(self):
|
||||
collection_list = self.chroma_client.list_collections()
|
||||
return collection_list
|
||||
|
||||
def get_collection(self, collection_name:str):
|
||||
collection = self.chroma_client.get_collection(name=collection_name, embedding_function=self.embedding_func)
|
||||
return collection
|
||||
|
||||
def create_collection(self, collection_name:str):
|
||||
collection = self.chroma_client.create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
|
||||
return collection
|
||||
|
||||
def get_or_create_collection(self, collection_name:str):
|
||||
collection = self.chroma_client.get_or_create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
|
||||
return collection
|
||||
|
||||
def delete_collection(self, collection_name:str):
|
||||
self.chroma_client.delete_collection(collection_name)
|
||||
return True
|
||||
@@ -8,80 +8,30 @@ from typing import Any, List, Mapping, Optional, Union
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
from instances.logging_instance import logger
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
|
||||
from util.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.chromadb_instance import client
|
||||
|
||||
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
||||
from util.chromadb_instance import (client,
|
||||
get_chroma_collection_size, query_chroma_collection,
|
||||
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
|
||||
get_chroma_collection_by_ids, get_chroma_collection_size,
|
||||
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
|
||||
empty_chroma_collection_2)
|
||||
from retriever import ChromaCollectionRetriever, CollectionManager
|
||||
|
||||
|
||||
emb_func = Text2VecEmbeddingFunction()
|
||||
|
||||
solved_query_collection = client.get_or_create_collection(name=SOLVED_QUERY_COLLECTION_NAME,
|
||||
embedding_function=emb_func,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
|
||||
logger.info("init_solved_query_collection_size: {}", get_chroma_collection_size(solved_query_collection))
|
||||
collection_manager = CollectionManager(chroma_client=client, embedding_func=emb_func
|
||||
,collection_meta={"hnsw:space": "cosine"})
|
||||
|
||||
solved_query_collection = collection_manager.get_or_create_collection(collection_name=SOLVED_QUERY_COLLECTION_NAME)
|
||||
preset_query_collection = collection_manager.get_or_create_collection(collection_name=PRESET_QUERY_COLLECTION_NAME)
|
||||
|
||||
preset_query_collection = client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME,
|
||||
embedding_function=emb_func,
|
||||
metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
logger.info("init_preset_query_collection_size: {}", get_chroma_collection_size(preset_query_collection))
|
||||
|
||||
class ChromaCollectionRetriever(object):
|
||||
def __init__(self, collection:Collection):
|
||||
self.collection = collection
|
||||
|
||||
def retrieval_query_run(self, query_texts_list:List[str],
|
||||
filter_condition:Mapping[str,str]=None, n_results:int=5):
|
||||
|
||||
retrieval_res = query_chroma_collection(self.collection, query_texts_list,
|
||||
filter_condition, n_results)
|
||||
|
||||
parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
|
||||
parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, parsed_retrieval_res)
|
||||
|
||||
logger.info('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)
|
||||
|
||||
return parsed_retrieval_res_format
|
||||
|
||||
def get_query_by_ids(self, query_ids:List[str]):
|
||||
queries = get_chroma_collection_by_ids(self.collection, query_ids)
|
||||
return queries
|
||||
|
||||
def get_query_size(self):
|
||||
return get_chroma_collection_size(self.collection)
|
||||
|
||||
def add_queries(self, query_text_list:List[str],
|
||||
query_id_list:List[str], metadatas:List[Mapping[str, str]]=None):
|
||||
add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas)
|
||||
return True
|
||||
|
||||
def update_queries(self, query_text_list:List[str],
|
||||
query_id_list:List[str], metadatas:List[Mapping[str, str]]=None):
|
||||
update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas)
|
||||
return True
|
||||
|
||||
def delete_queries_by_ids(self, query_ids:List[str]):
|
||||
delete_chroma_collection_by_ids(self.collection, query_ids)
|
||||
return True
|
||||
|
||||
def empty_query_collection(self):
|
||||
self.collection = empty_chroma_collection_2(self.collection)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
solved_query_retriever = ChromaCollectionRetriever(solved_query_collection)
|
||||
preset_query_retriever = ChromaCollectionRetriever(preset_query_collection)
|
||||
|
||||
logger.info("init_solved_query_collection_size: {}".format(solved_query_retriever.get_query_size()))
|
||||
logger.info("init_preset_query_collection_size: {}".format(preset_query_retriever.get_query_size()))
|
||||
|
||||
@@ -2,87 +2,64 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
from chromadb.api import Collection
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
from instances.logging_instance import logger
|
||||
from services.query_retrieval.retriever import ChromaCollectionRetriever
|
||||
|
||||
from langchain.vectorstores import Chroma
|
||||
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
||||
class FewShotPromptTemplate2(object):
|
||||
def __init__(self, collection:Collection, few_shot_examples:List[Mapping[str, str]],
|
||||
retrieval_key:str, few_shot_seperator:str = "\n\n") -> None:
|
||||
self.collection = collection
|
||||
self.few_shot_retriever = ChromaCollectionRetriever(self.collection)
|
||||
|
||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||
from util.text2vec import hg_embedding
|
||||
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
||||
from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||
self.few_shot_examples = few_shot_examples
|
||||
self.retrieval_key = retrieval_key
|
||||
|
||||
self.few_shot_seperator = few_shot_seperator
|
||||
|
||||
def reload_sql_example_collection(
|
||||
vectorstore: Chroma,
|
||||
sql_examplars: List[Mapping[str, str]],
|
||||
sql_example_selector: SemanticSimilarityExampleSelector,
|
||||
example_nums: int,
|
||||
):
|
||||
logger.info("original sql_examples_collection size: {}", vectorstore._collection.count())
|
||||
new_collection = empty_chroma_collection_2(collection=vectorstore._collection)
|
||||
vectorstore._collection = new_collection
|
||||
def add_few_shot_example(self, examples: List[Mapping[str, str]])-> None:
|
||||
query_text_list = []
|
||||
query_id_list = []
|
||||
for idx, example in enumerate(examples):
|
||||
query_text_list.append(example[self.retrieval_key])
|
||||
query_id_list.append(str(idx))
|
||||
|
||||
logger.info("emptied sql_examples_collection size: {}", vectorstore._collection.count())
|
||||
self.few_shot_retriever.add_queries(query_text_list=query_text_list, query_id_list=query_id_list, metadatas=examples)
|
||||
|
||||
sql_example_selector = SemanticSimilarityExampleSelector(
|
||||
vectorstore=sql_examples_vectorstore,
|
||||
k=example_nums,
|
||||
input_keys=["question"],
|
||||
example_keys=[
|
||||
"table_name",
|
||||
"fields_list",
|
||||
"prior_schema_links",
|
||||
"question",
|
||||
"analysis",
|
||||
"schema_links",
|
||||
"current_date",
|
||||
"sql",
|
||||
],
|
||||
)
|
||||
def reload_few_shot_example(self, examples: List[Mapping[str, str]])-> None:
|
||||
logger.info(f"original sql_examples_collection size: {self.few_shot_retriever.get_query_size()}")
|
||||
|
||||
for example in sql_examplars:
|
||||
sql_example_selector.add_example(example)
|
||||
self.few_shot_retriever.empty_query_collection()
|
||||
logger.info(f"emptied sql_examples_collection size: {self.few_shot_retriever.get_query_size()}")
|
||||
|
||||
logger.info("reloaded sql_examples_collection size: {}", vectorstore._collection.count())
|
||||
self.add_few_shot_example(examples=examples)
|
||||
logger.info(f"reloaded sql_examples_collection size: {self.few_shot_retriever.get_query_size()}")
|
||||
|
||||
return vectorstore, sql_example_selector
|
||||
def _sub_dict(self, d:Mapping[str, str], keys:List[str])-> Mapping[str, str]:
|
||||
return {k:d[k] for k in keys if k in d}
|
||||
|
||||
def retrieve_few_shot_example(self, query_text: str, retrieval_num: int)-> List[Mapping[str, str]]:
|
||||
query_text_list = [query_text]
|
||||
retrieval_res_list = self.few_shot_retriever.retrieval_query_run(query_texts_list=query_text_list,
|
||||
filter_condition=None, n_results=retrieval_num)
|
||||
retrieval_res_unit_list = retrieval_res_list[0]['retrieval']
|
||||
|
||||
sql_examples_vectorstore = Chroma(
|
||||
collection_name=TEXT2DSL_COLLECTION_NAME,
|
||||
embedding_function=hg_embedding,
|
||||
client=chromadb_client,
|
||||
)
|
||||
return retrieval_res_unit_list
|
||||
|
||||
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||
def make_few_shot_example_prompt(self, few_shot_template: str, example_keys: List[str],
|
||||
few_shot_example_meta_list: List[Mapping[str, str]])-> str:
|
||||
few_shot_example_str_unit_list = []
|
||||
|
||||
sql_example_selector = SemanticSimilarityExampleSelector(
|
||||
vectorstore=sql_examples_vectorstore,
|
||||
k=example_nums,
|
||||
input_keys=["question"],
|
||||
example_keys=[
|
||||
"table_name",
|
||||
"fields_list",
|
||||
"prior_schema_links",
|
||||
"question",
|
||||
"analysis",
|
||||
"schema_links",
|
||||
"current_date",
|
||||
"sql",
|
||||
],
|
||||
)
|
||||
retrieval_metas_list = [self._sub_dict(few_shot_example_meta['metadata'], example_keys) for few_shot_example_meta in few_shot_example_meta_list]
|
||||
|
||||
if sql_examples_vectorstore._collection.count() > 0:
|
||||
logger.info("examples already in sql_vectorstore")
|
||||
logger.info("init sql_vectorstore size: {}", sql_examples_vectorstore._collection.count())
|
||||
for meta in retrieval_metas_list:
|
||||
few_shot_example_str_unit_list.append(few_shot_template.format(**meta))
|
||||
|
||||
few_shot_example_str = self.few_shot_seperator.join(few_shot_example_str_unit_list)
|
||||
|
||||
return few_shot_example_str
|
||||
|
||||
logger.info("sql_examplars size: {}", len(sql_examplars))
|
||||
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(
|
||||
sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums
|
||||
)
|
||||
logger.info("added sql_vectorstore size: {}", sql_examples_vectorstore._collection.count())
|
||||
|
||||
@@ -6,41 +6,54 @@ from typing import List, Mapping
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT)
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from config.config_parse import (TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT, TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||
from util.logging_utils import logger
|
||||
|
||||
|
||||
def text2dsl_setting_update(
|
||||
llm_parser_host: str,
|
||||
llm_parser_port: str,
|
||||
sql_examplars: List[Mapping[str, str]],
|
||||
example_nums: int,
|
||||
is_shortcut: bool,
|
||||
):
|
||||
def text2sql_agent_setting_update(llm_host:str, llm_port:str,
|
||||
sql_examplars:List[Mapping[str, str]], example_nums:int):
|
||||
|
||||
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
|
||||
logger.info("url: {}", url)
|
||||
payload = {
|
||||
"sqlExamplars": sql_examplars,
|
||||
"exampleNums": example_nums,
|
||||
"isShortcut": is_shortcut,
|
||||
}
|
||||
headers = {"content-type": "application/json"}
|
||||
url = f"http://{llm_host}:{llm_port}/text2sql_agent_setting_update/"
|
||||
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums}
|
||||
headers = {'content-type': 'application/json'}
|
||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
logger.info(response.text)
|
||||
|
||||
|
||||
def text2dsl_agent_cs_setting_update(llm_host:str, llm_port:str,
|
||||
sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int):
|
||||
|
||||
url = f"http://{llm_host}:{llm_port}/texg2sqt_cs_agent_setting_update/"
|
||||
payload = {"sqlExamplars":sql_examplars,
|
||||
"exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums}
|
||||
headers = {'content-type': 'application/json'}
|
||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
logger.info(response.text)
|
||||
|
||||
|
||||
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
|
||||
is_shortcut:bool, is_self_consistency:bool,
|
||||
sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int):
|
||||
|
||||
url = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
|
||||
payload = {"isShortcut":is_shortcut, "isSelfConsistency":is_self_consistency,
|
||||
"sqlExamplars":sql_examplars,
|
||||
"exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums}
|
||||
headers = {'content-type': 'application/json'}
|
||||
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||
logger.info(response.text)
|
||||
|
||||
if __name__ == "__main__":
|
||||
text2dsl_setting_update(
|
||||
LLMPARSER_HOST,
|
||||
LLMPARSER_PORT,
|
||||
sql_examplars,
|
||||
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM,
|
||||
TEXT2DSL_IS_SHORTCUT,
|
||||
)
|
||||
text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT,
|
||||
TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY,
|
||||
sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
|
||||
|
||||
|
||||
|
||||
@@ -8,16 +8,14 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
def schema_link_parse(schema_link_output):
|
||||
try:
|
||||
schema_link_output = schema_link_output.strip()
|
||||
pattern = r"Schema_links:(.*)"
|
||||
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[
|
||||
0
|
||||
].strip()
|
||||
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
schema_link_output = None
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
||||
|
||||
|
||||
def schema_linking_exampler(
|
||||
user_query: str,
|
||||
domain_name: str,
|
||||
fields_list: List[str],
|
||||
prior_schema_links: Mapping[str, str],
|
||||
example_selector: SemanticSimilarityExampleSelector,
|
||||
) -> str:
|
||||
|
||||
prior_schema_links_str = (
|
||||
"["
|
||||
+ ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()])
|
||||
+ "]"
|
||||
)
|
||||
|
||||
example_prompt_template = PromptTemplate(
|
||||
input_variables=[
|
||||
"table_name",
|
||||
"fields_list",
|
||||
"prior_schema_links",
|
||||
"question",
|
||||
"analysis",
|
||||
"schema_links",
|
||||
],
|
||||
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}",
|
||||
)
|
||||
|
||||
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
|
||||
|
||||
schema_linking_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
|
||||
|
||||
schema_linking_example_prompt_template = FewShotPromptTemplate(
|
||||
example_selector=example_selector,
|
||||
example_prompt=example_prompt_template,
|
||||
example_separator="\n\n",
|
||||
prefix=instruction,
|
||||
input_variables=["table_name", "fields_list", "prior_schema_links", "question"],
|
||||
suffix=schema_linking_prompt,
|
||||
)
|
||||
|
||||
schema_linking_example_prompt = schema_linking_example_prompt_template.format(
|
||||
table_name=domain_name,
|
||||
fields_list=fields_list,
|
||||
prior_schema_links=prior_schema_links_str,
|
||||
question=user_query,
|
||||
)
|
||||
|
||||
return schema_linking_example_prompt
|
||||
|
||||
|
||||
def sql_exampler(
|
||||
user_query: str,
|
||||
domain_name: str,
|
||||
schema_link_str: str,
|
||||
data_date: str,
|
||||
example_selector: SemanticSimilarityExampleSelector,
|
||||
) -> str:
|
||||
|
||||
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
|
||||
|
||||
sql_example_prompt_template = PromptTemplate(
|
||||
input_variables=[
|
||||
"question",
|
||||
"current_date",
|
||||
"table_name",
|
||||
"schema_links",
|
||||
"sql",
|
||||
],
|
||||
template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}",
|
||||
)
|
||||
|
||||
sql_prompt = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
|
||||
|
||||
sql_example_prompt_template = FewShotPromptTemplate(
|
||||
example_selector=example_selector,
|
||||
example_prompt=sql_example_prompt_template,
|
||||
example_separator="\n\n",
|
||||
prefix=instruction,
|
||||
input_variables=["question", "current_date", "table_name", "schema_links"],
|
||||
suffix=sql_prompt,
|
||||
)
|
||||
|
||||
sql_example_prompt = sql_example_prompt_template.format(
|
||||
question=user_query,
|
||||
current_date=data_date,
|
||||
table_name=domain_name,
|
||||
schema_links=schema_link_str,
|
||||
)
|
||||
|
||||
return sql_example_prompt
|
||||
|
||||
|
||||
def schema_linking_sql_combo_examplar(
|
||||
user_query: str,
|
||||
domain_name: str,
|
||||
data_date: str,
|
||||
fields_list: List[str],
|
||||
prior_schema_links: Mapping[str, str],
|
||||
example_selector: SemanticSimilarityExampleSelector,
|
||||
) -> str:
|
||||
|
||||
prior_schema_links_str = (
|
||||
"["
|
||||
+ ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()])
|
||||
+ "]"
|
||||
)
|
||||
|
||||
example_prompt_template = PromptTemplate(
|
||||
input_variables=[
|
||||
"table_name",
|
||||
"fields_list",
|
||||
"prior_schema_links",
|
||||
"current_date",
|
||||
"question",
|
||||
"analysis",
|
||||
"schema_links",
|
||||
"sql",
|
||||
],
|
||||
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}",
|
||||
)
|
||||
|
||||
instruction = (
|
||||
"# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
|
||||
)
|
||||
|
||||
schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
|
||||
|
||||
schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate(
|
||||
example_selector=example_selector,
|
||||
example_prompt=example_prompt_template,
|
||||
example_separator="\n\n",
|
||||
prefix=instruction,
|
||||
input_variables=[
|
||||
"table_name",
|
||||
"fields_list",
|
||||
"prior_schema_links",
|
||||
"current_date",
|
||||
"question",
|
||||
],
|
||||
suffix=schema_linking_sql_combo_prompt,
|
||||
)
|
||||
|
||||
schema_linking_sql_combo_example_prompt = (
|
||||
schema_linking_sql_combo_example_prompt_template.format(
|
||||
table_name=domain_name,
|
||||
fields_list=fields_list,
|
||||
prior_schema_links=prior_schema_links_str,
|
||||
current_date=data_date,
|
||||
question=user_query,
|
||||
)
|
||||
)
|
||||
return schema_linking_sql_combo_example_prompt
|
||||
@@ -1,188 +1,56 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
from sql.constructor import FewShotPromptTemplate2
|
||||
from sql.sql_agent import Text2DSLAgent, Text2DSLAgentConsistency, Text2DSLAgentWrapper
|
||||
|
||||
from sql.prompt_maker import (
|
||||
schema_linking_exampler,
|
||||
sql_exampler,
|
||||
schema_linking_sql_combo_examplar,
|
||||
)
|
||||
from sql.constructor import (
|
||||
sql_examples_vectorstore,
|
||||
sql_example_selector,
|
||||
reload_sql_example_collection,
|
||||
)
|
||||
from sql.output_parser import (
|
||||
schema_link_parse,
|
||||
combo_schema_link_parse,
|
||||
combo_sql_parse,
|
||||
)
|
||||
from instances.llm_instance import llm
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.chromadb_instance import client
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from util.llm_instance import llm
|
||||
from config.config_parse import TEXT2DSL_IS_SHORTCUT
|
||||
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTCS_COLLECTION_NAME,
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
|
||||
|
||||
class Text2DSLAgent(object):
|
||||
def __init__(self):
|
||||
self.schema_linking_exampler = schema_linking_exampler
|
||||
self.sql_exampler = sql_exampler
|
||||
emb_func = Text2VecEmbeddingFunction()
|
||||
text2dsl_agent_collection = client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME,
|
||||
embedding_function=emb_func,
|
||||
metadata={"hnsw:space": "cosine"})
|
||||
text2dsl_agentcs_collection = client.get_or_create_collection(name=TEXT2DSLAGENTCS_COLLECTION_NAME,
|
||||
embedding_function=emb_func,
|
||||
metadata={"hnsw:space": "cosine"})
|
||||
|
||||
self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar
|
||||
text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection,
|
||||
few_shot_examples=sql_examplars,
|
||||
retrieval_key="question",
|
||||
few_shot_seperator='\n\n')
|
||||
|
||||
self.sql_examples_vectorstore = sql_examples_vectorstore
|
||||
self.sql_example_selector = sql_example_selector
|
||||
text2dsl_agentcs_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agentcs_collection,
|
||||
few_shot_examples=sql_examplars,
|
||||
retrieval_key="question",
|
||||
few_shot_seperator='\n\n')
|
||||
|
||||
self.schema_link_parse = schema_link_parse
|
||||
self.combo_schema_link_parse = combo_schema_link_parse
|
||||
self.combo_sql_parse = combo_sql_parse
|
||||
text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_EXAMPLE_NUM,
|
||||
sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)
|
||||
|
||||
self.llm = llm
|
||||
text2sql_cs_agent = Text2DSLAgentConsistency(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
sql_example_prompter=text2dsl_agentcs_example_prompter, llm=llm)
|
||||
|
||||
self.is_shortcut = TEXT2DSL_IS_SHORTCUT
|
||||
text2sql_agent.update_examples(sql_examplars, TEXT2DSL_EXAMPLE_NUM)
|
||||
|
||||
def update_examples(self, sql_examples, example_nums, is_shortcut):
|
||||
(
|
||||
self.sql_examples_vectorstore,
|
||||
self.sql_example_selector,
|
||||
) = reload_sql_example_collection(
|
||||
self.sql_examples_vectorstore,
|
||||
sql_examples,
|
||||
self.sql_example_selector,
|
||||
example_nums,
|
||||
)
|
||||
self.is_shortcut = is_shortcut
|
||||
|
||||
def query2sql(
|
||||
self,
|
||||
query_text: str,
|
||||
schema: Union[dict, None] = None,
|
||||
current_date: str = None,
|
||||
linking: Union[List[Mapping[str, str]], None] = None,
|
||||
):
|
||||
|
||||
logger.info("query_text: {}".format(query_text))
|
||||
logger.info("schema: {}".format(schema))
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(linking))
|
||||
|
||||
if linking is not None:
|
||||
prior_schema_links = {
|
||||
item["fieldValue"]: item["fieldName"] for item in linking
|
||||
}
|
||||
else:
|
||||
prior_schema_links = {}
|
||||
|
||||
model_name = schema["modelName"]
|
||||
fields_list = schema["fieldNameList"]
|
||||
|
||||
schema_linking_prompt = self.schema_linking_exampler(
|
||||
query_text,
|
||||
model_name,
|
||||
fields_list,
|
||||
prior_schema_links,
|
||||
self.sql_example_selector,
|
||||
)
|
||||
logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt))
|
||||
schema_link_output = self.llm(schema_linking_prompt)
|
||||
schema_link_str = self.schema_link_parse(schema_link_output)
|
||||
|
||||
sql_prompt = self.sql_exampler(
|
||||
query_text,
|
||||
model_name,
|
||||
schema_link_str,
|
||||
current_date,
|
||||
self.sql_example_selector,
|
||||
)
|
||||
logger.info("sql_prompt-> {}".format(sql_prompt))
|
||||
sql_output = self.llm(sql_prompt)
|
||||
|
||||
resp = dict()
|
||||
resp["query"] = query_text
|
||||
resp["model"] = model_name
|
||||
resp["fields"] = fields_list
|
||||
resp["priorSchemaLinking"] = linking
|
||||
resp["dataDate"] = current_date
|
||||
|
||||
resp["analysisOutput"] = schema_link_output
|
||||
resp["schemaLinkStr"] = schema_link_str
|
||||
|
||||
resp["sqlOutput"] = sql_output
|
||||
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
def query2sqlcombo(
|
||||
self,
|
||||
query_text: str,
|
||||
schema: Union[dict, None] = None,
|
||||
current_date: str = None,
|
||||
linking: Union[List[Mapping[str, str]], None] = None,
|
||||
):
|
||||
|
||||
logger.info("query_text: {}".format(query_text))
|
||||
logger.info("schema: {}".format(schema))
|
||||
logger.info("current_date: {}".format(current_date))
|
||||
logger.info("prior_schema_links: {}".format(linking))
|
||||
|
||||
if linking is not None:
|
||||
prior_schema_links = {
|
||||
item["fieldValue"]: item["fieldName"] for item in linking
|
||||
}
|
||||
else:
|
||||
prior_schema_links = {}
|
||||
|
||||
model_name = schema["modelName"]
|
||||
fields_list = schema["fieldNameList"]
|
||||
|
||||
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(
|
||||
query_text,
|
||||
model_name,
|
||||
current_date,
|
||||
fields_list,
|
||||
prior_schema_links,
|
||||
self.sql_example_selector,
|
||||
)
|
||||
logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt))
|
||||
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
|
||||
|
||||
schema_linking_str = self.combo_schema_link_parse(
|
||||
schema_linking_sql_combo_output
|
||||
)
|
||||
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
|
||||
|
||||
resp = dict()
|
||||
resp["query"] = query_text
|
||||
resp["model"] = model_name
|
||||
resp["fields"] = fields_list
|
||||
resp["priorSchemaLinking"] = prior_schema_links
|
||||
resp["dataDate"] = current_date
|
||||
|
||||
resp["analysisOutput"] = schema_linking_sql_combo_output
|
||||
resp["schemaLinkStr"] = schema_linking_str
|
||||
resp["sqlOutput"] = sql_str
|
||||
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
def query2sql_run(
|
||||
self,
|
||||
query_text: str,
|
||||
schema: Union[dict, None] = None,
|
||||
current_date: str = None,
|
||||
linking: Union[List[Mapping[str, str]], None] = None,
|
||||
):
|
||||
|
||||
if self.is_shortcut:
|
||||
return self.query2sqlcombo(query_text, schema, current_date, linking)
|
||||
else:
|
||||
return self.query2sql(query_text, schema, current_date, linking)
|
||||
text2sql_cs_agent.update_examples(sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
|
||||
|
||||
|
||||
text2sql_agent = Text2DSLAgent()
|
||||
text2sql_agent_router = Text2DSLAgentWrapper(sql_agent=text2sql_agent, sql_agent_cs=text2sql_cs_agent,
|
||||
is_shortcut=TEXT2DSL_IS_SHORTCUT, is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
380
chat/core/src/main/python/services/sql/sql_agent.py
Normal file
380
chat/core/src/main/python/services/sql/sql_agent.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping, Any
|
||||
from collections import Counter
|
||||
import random
|
||||
import asyncio
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
from sql.constructor import FewShotPromptTemplate2
|
||||
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
|
||||
|
||||
class Text2DSLAgent(object):
|
||||
def __init__(self, num_fewshots:int,
|
||||
sql_example_prompter:FewShotPromptTemplate2,
|
||||
llm: BaseLLM):
|
||||
self.num_fewshots = num_fewshots
|
||||
self.sql_example_prompter = sql_example_prompter
|
||||
self.llm = llm
|
||||
|
||||
def update_examples(self, sql_examplars, num_fewshots):
|
||||
self.num_fewshots = num_fewshots
|
||||
self.sql_example_prompter.reload_few_shot_example(sql_examplars)
|
||||
|
||||
def get_fewshot_examples(self, query_text: str)->List[Mapping[str, str]]:
|
||||
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_fewshots)
|
||||
|
||||
return few_shot_example_meta_list
|
||||
|
||||
def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
|
||||
prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
|
||||
|
||||
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
|
||||
|
||||
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
|
||||
|
||||
schema_linking_example_keys = ["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]
|
||||
schema_linking_example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}"
|
||||
schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
|
||||
example_keys=schema_linking_example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
|
||||
new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, question=user_query)
|
||||
|
||||
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
|
||||
return schema_linking_prompt
|
||||
|
||||
def generate_sql_prompt(self, user_query: str, domain_name: str,
|
||||
schema_link_str: str, data_date: str,
|
||||
fewshot_example_list:List[Mapping[str, str]])-> str:
|
||||
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
|
||||
sql_example_keys = ["question", "current_date", "table_name", "schema_links", "sql"]
|
||||
sql_example_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}"
|
||||
|
||||
|
||||
sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
|
||||
example_keys=sql_example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
new_case_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
|
||||
new_case_prompt = new_case_template.format(question=user_query, current_date=data_date, table_name=domain_name, schema_links=schema_link_str)
|
||||
|
||||
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
|
||||
|
||||
return sql_example_prompt
|
||||
|
||||
def generate_schema_linking_sql_prompt(self, user_query: str,
|
||||
domain_name: str,
|
||||
data_date : str,
|
||||
fields_list: List[str],
|
||||
prior_schema_links: Mapping[str,str],
|
||||
fewshot_example_list:List[Mapping[str, str]]):
|
||||
|
||||
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
|
||||
|
||||
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
|
||||
|
||||
example_keys = ["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"]
|
||||
example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}"
|
||||
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
|
||||
example_keys=example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
|
||||
new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, current_date=data_date, question=user_query)
|
||||
|
||||
prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt
|
||||
|
||||
return prompt
|
||||
|
||||
async def async_query2sql(self, query_text: str,
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
logger.info("query_text: {}".format(query_text))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
logger.info("fields_list: {}".format(fields_list))
|
||||
logger.info("data_date: {}".format(data_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
|
||||
query_text = query_text + ' 备注:'+prior_exts
|
||||
logger.info("query_text_prior_exts: {}".format(query_text))
|
||||
|
||||
fewshot_example_meta_list = self.get_fewshot_examples(query_text)
|
||||
schema_linking_prompt = self.generate_schema_linking_prompt(query_text, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
|
||||
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
|
||||
schema_link_output = await self.llm._call_async(schema_linking_prompt)
|
||||
|
||||
schema_link_str = schema_link_parse(schema_link_output)
|
||||
|
||||
sql_prompt = self.generate_sql_prompt(query_text, model_name, schema_link_str, data_date, fewshot_example_meta_list)
|
||||
logger.debug("sql_prompt->{}".format(sql_prompt))
|
||||
sql_output = await self.llm._call_async(sql_prompt)
|
||||
|
||||
resp = dict()
|
||||
resp['query'] = query_text
|
||||
resp['model'] = model_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['dataDate'] = data_date
|
||||
|
||||
resp['schemaLinkingOutput'] = schema_link_output
|
||||
resp['schemaLinkStr'] = schema_link_str
|
||||
|
||||
resp['sqlOutput'] = sql_output
|
||||
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
async def async_query2sql_shortcut(self, query_text: str,
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
logger.info("query_text: {}".format(query_text))
|
||||
logger.info("model_name: {}".format(model_name))
|
||||
logger.info("fields_list: {}".format(fields_list))
|
||||
logger.info("data_date: {}".format(data_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
|
||||
query_text = query_text + ' 备注:'+prior_exts
|
||||
logger.info("query_text_prior_exts: {}".format(query_text))
|
||||
|
||||
fewshot_example_meta_list = self.get_fewshot_examples(query_text)
|
||||
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(query_text, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
|
||||
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
|
||||
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
|
||||
|
||||
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
|
||||
sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)
|
||||
|
||||
resp = dict()
|
||||
resp['query'] = query_text
|
||||
resp['model'] = model_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['dataDate'] = data_date
|
||||
|
||||
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
|
||||
resp['schemaLinkStr'] = schema_linking_str
|
||||
resp['sqlOutput'] = sql_str
|
||||
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
class Text2DSLAgentConsistency(object):
|
||||
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
|
||||
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM) -> None:
|
||||
self.num_fewshots = num_fewshots
|
||||
self.num_examples = num_examples
|
||||
assert self.num_fewshots <= self.num_examples
|
||||
self.num_self_consistency = num_self_consistency
|
||||
|
||||
self.llm = llm
|
||||
self.sql_example_prompter = sql_example_prompter
|
||||
|
||||
def update_examples(self, sql_examplars, num_examples, num_fewshots, num_self_consistency):
|
||||
self.num_fewshots = num_fewshots
|
||||
self.num_examples = num_examples
|
||||
assert self.num_fewshots <= self.num_examples
|
||||
self.num_self_consistency = num_self_consistency
|
||||
assert self.num_self_consistency >= 1
|
||||
self.sql_example_prompter.reload_few_shot_example(sql_examplars)
|
||||
|
||||
def get_examples_candidates(self, query_text: str)->List[Mapping[str, str]]:
|
||||
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_examples)
|
||||
|
||||
return few_shot_example_meta_list
|
||||
|
||||
def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]])-> List[List[Mapping[str, str]]]:
|
||||
fewshot_example_list = []
|
||||
for i in range(0, self.num_self_consistency):
|
||||
random.shuffle(example_meta_list)
|
||||
fewshot_example_list.append(example_meta_list[:self.num_fewshots])
|
||||
|
||||
return fewshot_example_list
|
||||
|
||||
def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
|
||||
prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
|
||||
|
||||
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
|
||||
|
||||
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
|
||||
|
||||
schema_linking_example_keys = ["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]
|
||||
schema_linking_example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}"
|
||||
schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
|
||||
example_keys=schema_linking_example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
|
||||
new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, question=user_query)
|
||||
|
||||
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
|
||||
return schema_linking_prompt
|
||||
|
||||
def generate_schema_linking_prompt_pool(self, user_query: str, domain_name: str, fields_list: List[str],
|
||||
prior_schema_links: Mapping[str,str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
|
||||
schema_linking_prompt_pool = []
|
||||
for fewshot_example_list in fewshot_example_list_pool:
|
||||
schema_linking_prompt = self.generate_schema_linking_prompt(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list)
|
||||
schema_linking_prompt_pool.append(schema_linking_prompt)
|
||||
|
||||
return schema_linking_prompt_pool
|
||||
|
||||
def generate_sql_prompt(self, user_query: str, domain_name: str,
|
||||
schema_link_str: str, data_date: str,
|
||||
fewshot_example_list:List[Mapping[str, str]])-> str:
|
||||
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
|
||||
sql_example_keys = ["question", "current_date", "table_name", "schema_links", "sql"]
|
||||
sql_example_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}"
|
||||
|
||||
|
||||
sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
|
||||
example_keys=sql_example_keys,
|
||||
few_shot_example_meta_list=fewshot_example_list)
|
||||
|
||||
new_case_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
|
||||
new_case_prompt = new_case_template.format(question=user_query, current_date=data_date, table_name=domain_name, schema_links=schema_link_str)
|
||||
|
||||
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
|
||||
|
||||
return sql_example_prompt
|
||||
|
||||
def generate_sql_prompt_pool(self, user_query: str, domain_name: str, data_date: str,
|
||||
schema_link_str_pool: List[str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
|
||||
sql_prompt_pool = []
|
||||
for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
|
||||
sql_prompt = self.generate_sql_prompt(user_query, domain_name, schema_link_str, data_date, fewshot_example_list)
|
||||
sql_prompt_pool.append(sql_prompt)
|
||||
|
||||
return sql_prompt_pool
|
||||
|
||||
def self_consistency_vote(self, output_res_pool:List[str]):
|
||||
output_res_counts = Counter(output_res_pool)
|
||||
output_res_max = output_res_counts.most_common(1)[0][0]
|
||||
total_output_num = len(output_res_pool)
|
||||
|
||||
vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()}
|
||||
|
||||
return output_res_max, vote_percentage
|
||||
|
||||
def schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]:
|
||||
schema_linking_list_unify = []
|
||||
for schema_linking_str in schema_linking_list:
|
||||
schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')]))
|
||||
schema_linking_str_unify = f'[{schema_linking_str_unify}]'
|
||||
schema_linking_list_unify.append(schema_linking_str_unify)
|
||||
|
||||
return schema_linking_list_unify
|
||||
|
||||
|
||||
async def generate_schema_linking_tasks(self, user_query: str, domain_name: str,
|
||||
fields_list: List[str], prior_schema_links: Mapping[str,str],
|
||||
fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
|
||||
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(user_query, domain_name,
|
||||
fields_list, prior_schema_links,
|
||||
fewshot_example_list_combo)
|
||||
schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
|
||||
schema_linking_output_res_pool = await asyncio.gather(*schema_linking_output_task_pool)
|
||||
logger.debug(f'schema_linking_output_res_pool:{schema_linking_output_res_pool}')
|
||||
|
||||
return schema_linking_output_res_pool
|
||||
|
||||
async def generate_sql_tasks(self, user_query: str, domain_name: str, data_date: str,
|
||||
schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
|
||||
|
||||
sql_prompt_pool = self.generate_sql_prompt_pool(user_query, domain_name, schema_link_str_pool, data_date, fewshot_example_list_combo)
|
||||
sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
|
||||
sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
|
||||
logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')
|
||||
|
||||
return sql_output_res_pool
|
||||
|
||||
async def tasks_run(self, user_query: str, domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
|
||||
logger.info("user_query: {}".format(user_query))
|
||||
logger.info("domain_name: {}".format(domain_name))
|
||||
logger.info("fields_list: {}".format(fields_list))
|
||||
logger.info("current_date: {}".format(data_date))
|
||||
logger.info("prior_schema_links: {}".format(prior_schema_links))
|
||||
logger.info("prior_exts: {}".format(prior_exts))
|
||||
|
||||
user_query = user_query + ' 备注:'+prior_exts
|
||||
logger.info("user_query_prior_exts: {}".format(user_query))
|
||||
|
||||
fewshot_example_meta_list = self.get_examples_candidates(user_query)
|
||||
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list)
|
||||
|
||||
schema_linking_output_candidates = await self.generate_schema_linking_tasks(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
|
||||
schema_linking_candidate_list = [schema_link_parse(schema_linking_output_candidate) for schema_linking_output_candidate in schema_linking_output_candidates]
|
||||
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
|
||||
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
|
||||
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
|
||||
|
||||
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
|
||||
|
||||
sql_output_candicates = await self.generate_sql_tasks(user_query, domain_name, data_date, schema_linking_candidate_list,fewshot_example_list_combo)
|
||||
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
|
||||
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
|
||||
|
||||
resp = dict()
|
||||
resp['query'] = user_query
|
||||
resp['model'] = domain_name
|
||||
resp['fields'] = fields_list
|
||||
resp['priorSchemaLinking'] = prior_schema_links
|
||||
resp['dataDate'] = data_date
|
||||
|
||||
resp['schemaLinkStr'] = schema_linking_output_max
|
||||
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
|
||||
|
||||
resp['sqlOutput'] = sql_output_max
|
||||
resp['sqlWeight'] = sql_output_vote_percentage
|
||||
|
||||
logger.info("resp: {}".format(resp))
|
||||
|
||||
return resp
|
||||
|
||||
class Text2DSLAgentWrapper(object):
|
||||
def __init__(self, sql_agent:Text2DSLAgent, sql_agent_cs:Text2DSLAgentConsistency,
|
||||
is_shortcut:bool, is_self_consistency:bool):
|
||||
self.sql_agent = sql_agent
|
||||
self.sql_agent_cs = sql_agent_cs
|
||||
|
||||
self.is_shortcut = is_shortcut
|
||||
self.is_self_consistency = is_self_consistency
|
||||
|
||||
async def async_query2sql(self, query_text: str,
|
||||
model_name: str, fields_list: List[str],
|
||||
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
|
||||
if self.is_self_consistency:
|
||||
logger.info("sql wrapper: self_consistency")
|
||||
resp = await self.sql_agent_cs.tasks_run(user_query=query_text, domain_name=model_name, fields_list=fields_list, prior_schema_links=prior_schema_links, data_date=data_date, prior_exts=prior_exts)
|
||||
return resp
|
||||
elif self.is_shortcut:
|
||||
logger.info("sql wrapper: shortcut")
|
||||
resp = await self.sql_agent.async_query2sql_shortcut(query_text=query_text, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
return resp
|
||||
else:
|
||||
logger.info("sql wrapper: normal")
|
||||
resp = await self.sql_agent.async_query2sql(query_text=query_text, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
return resp
|
||||
|
||||
def update_configs(self, is_shortcut, is_self_consistency,
|
||||
sql_examplars, num_examples, num_fewshots, num_self_consistency):
|
||||
self.is_shortcut = is_shortcut
|
||||
self.is_self_consistency = is_self_consistency
|
||||
|
||||
self.sql_agent.update_examples(sql_examplars=sql_examplars, num_fewshots=num_examples)
|
||||
self.sql_agent_cs.update_examples(sql_examplars=sql_examplars, num_examples=num_examples, num_fewshots=num_fewshots, num_self_consistency=num_self_consistency)
|
||||
|
||||
@@ -8,59 +8,81 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from services.sql.run import text2sql_agent
|
||||
from services.sql.run import text2sql_agent_router
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/query2sql/")
|
||||
def din_query2sql(query_body: Mapping[str, Any]):
|
||||
if "queryText" not in query_body:
|
||||
@router.post("/query2sql")
|
||||
async def query2sql(query_body: Mapping[str, Any]):
|
||||
if 'queryText' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="query_text is not in query_body")
|
||||
else:
|
||||
query_text = query_body["queryText"]
|
||||
query_text = query_body['queryText']
|
||||
|
||||
if "schema" not in query_body:
|
||||
if 'schema' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="schema is not in query_body")
|
||||
else:
|
||||
schema = query_body["schema"]
|
||||
|
||||
if "currentDate" not in query_body:
|
||||
schema = query_body['schema']
|
||||
|
||||
if 'currentDate' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="currentDate is not in query_body")
|
||||
else:
|
||||
current_date = query_body["currentDate"]
|
||||
current_date = query_body['currentDate']
|
||||
|
||||
if "linking" not in query_body:
|
||||
linking = None
|
||||
if 'linking' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="linking is not in query_body")
|
||||
else:
|
||||
linking = query_body["linking"]
|
||||
linking = query_body['linking']
|
||||
|
||||
resp = text2sql_agent.query2sql_run(
|
||||
query_text=query_text, schema=schema, current_date=current_date, linking=linking
|
||||
)
|
||||
if 'priorExts' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="prior_exts is not in query_body")
|
||||
else:
|
||||
prior_exts = query_body['priorExts']
|
||||
|
||||
model_name = schema['modelName']
|
||||
fields_list = schema['fieldNameList']
|
||||
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
||||
|
||||
resp = await text2sql_agent_router.async_query2sql(query_text=query_text, model_name=model_name, fields_list=fields_list,
|
||||
data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@router.post("/query2sql_setting_update/")
|
||||
@router.post("/query2sql_setting_update")
|
||||
def query2sql_setting_update(query_body: Mapping[str, Any]):
|
||||
if "sqlExamplars" not in query_body:
|
||||
if 'sqlExamplars' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
|
||||
else:
|
||||
sql_examplars = query_body["sqlExamplars"]
|
||||
sql_examplars = query_body['sqlExamplars']
|
||||
|
||||
if "exampleNums" not in query_body:
|
||||
if 'exampleNums' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
|
||||
else:
|
||||
example_nums = query_body["exampleNums"]
|
||||
example_nums = query_body['exampleNums']
|
||||
|
||||
if "isShortcut" not in query_body:
|
||||
if 'fewshotNums' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="fewshotNums is not in query_body")
|
||||
else:
|
||||
fewshot_nums = query_body['fewshotNums']
|
||||
|
||||
if 'selfConsistencyNums' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="selfConsistencyNums is not in query_body")
|
||||
else:
|
||||
self_consistency_nums = query_body['selfConsistencyNums']
|
||||
|
||||
if 'isShortcut' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
|
||||
else:
|
||||
is_shortcut = query_body["isShortcut"]
|
||||
is_shortcut = query_body['isShortcut']
|
||||
|
||||
text2sql_agent.update_examples(
|
||||
sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut
|
||||
)
|
||||
if 'isSelfConsistency' not in query_body:
|
||||
raise HTTPException(status_code=400, detail="isSelfConsistency is not in query_body")
|
||||
else:
|
||||
is_self_consistency = query_body['isSelfConsistency']
|
||||
|
||||
text2sql_agent_router.update_configs(is_shortcut=is_shortcut, is_self_consistency=is_self_consistency, sql_examplars=sql_examplars,
|
||||
num_examples=example_nums, num_fewshots=fewshot_nums, num_self_consistency=self_consistency_nums)
|
||||
|
||||
return "success"
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")
|
||||
@@ -4,22 +4,14 @@ from typing import Any, List, Mapping, Optional, Union
|
||||
import chromadb
|
||||
from chromadb.api import Collection
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from util.logging_utils import logger
|
||||
|
||||
from config.config_parse import CHROMA_DB_PERSIST_PATH
|
||||
|
||||
client = chromadb.Client(
|
||||
Settings(
|
||||
chroma_db_impl="duckdb+parquet",
|
||||
persist_directory=CHROMA_DB_PERSIST_PATH, # Optional, defaults to .chromadb/ in the current directory
|
||||
)
|
||||
)
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
def empty_chroma_collection_2(collection:Collection):
|
||||
@@ -48,26 +40,30 @@ def empty_chroma_collection(collection:Collection) -> None:
|
||||
def add_chroma_collection(collection:Collection,
|
||||
queries:List[str],
|
||||
query_ids:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None
|
||||
) -> None:
|
||||
|
||||
collection.add(documents=queries,
|
||||
ids=query_ids,
|
||||
metadatas=metadatas)
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings)
|
||||
|
||||
|
||||
def update_chroma_collection(collection:Collection,
|
||||
queries:List[str],
|
||||
query_ids:List[str],
|
||||
metadatas:List[Mapping[str, str]]=None
|
||||
metadatas:List[Mapping[str, str]]=None,
|
||||
embeddings:Embeddings=None
|
||||
) -> None:
|
||||
|
||||
collection.update(documents=queries,
|
||||
ids=query_ids,
|
||||
metadatas=metadatas)
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings)
|
||||
|
||||
|
||||
def query_chroma_collection(collection:Collection, query_texts:List[str],
|
||||
def query_chroma_collection(collection:Collection, query_texts:List[str]=None, query_embeddings:Embeddings=None,
|
||||
filter_condition:Mapping[str,str]=None, n_results:int=10):
|
||||
outer_opt = '$and'
|
||||
inner_opt = '$eq'
|
||||
@@ -81,8 +77,10 @@ def query_chroma_collection(collection:Collection, query_texts:List[str],
|
||||
else:
|
||||
outer_filter = None
|
||||
|
||||
print('outer_filter: ', outer_filter)
|
||||
res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
|
||||
logger.info('outer_filter: {}'.format(outer_filter))
|
||||
|
||||
res = collection.query(query_texts=query_texts, query_embeddings=query_embeddings,
|
||||
n_results=n_results, where=outer_filter)
|
||||
return res
|
||||
|
||||
|
||||
@@ -115,16 +113,32 @@ def parse_retrieval_chroma_collection_query(res:List[Mapping[str, Any]]):
|
||||
|
||||
return parsed_res
|
||||
|
||||
def chroma_collection_query_retrieval_format(query_list:List[str], retrieval_list:List[Mapping[str, Any]]):
|
||||
def chroma_collection_query_retrieval_format(query_list:List[str], query_embeddings:Embeddings ,retrieval_list:List[Mapping[str, Any]]):
|
||||
res = []
|
||||
for query_idx in range(0, len(query_list)):
|
||||
query = query_list[query_idx]
|
||||
retrieval = retrieval_list[query_idx]
|
||||
|
||||
res.append({
|
||||
'query': query,
|
||||
'retrieval': retrieval
|
||||
})
|
||||
if query_list is not None and query_embeddings is not None:
|
||||
raise Exception("query_list and query_embeddings are not None")
|
||||
if query_list is None and query_embeddings is None:
|
||||
raise Exception("query_list and query_embeddings are None")
|
||||
|
||||
if query_list is not None:
|
||||
for query_idx in range(0, len(query_list)):
|
||||
query = query_list[query_idx]
|
||||
retrieval = retrieval_list[query_idx]
|
||||
|
||||
res.append({
|
||||
'query': query,
|
||||
'retrieval': retrieval
|
||||
})
|
||||
else:
|
||||
for query_idx in range(0, len(query_embeddings)):
|
||||
query_embedding = query_embeddings[query_idx]
|
||||
retrieval = retrieval_list[query_idx]
|
||||
|
||||
res.append({
|
||||
'query_embedding': query_embedding,
|
||||
'retrieval': retrieval
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user