add self-consistency feature for text2sql (#303)

codescracker
2023-10-31 20:02:20 +08:00
committed by GitHub
parent ae9aa1ba0f
commit 438e8463f5
22 changed files with 764 additions and 727 deletions

View File

@@ -1,7 +1,14 @@
# -*- coding:utf-8 -*-
import os
import configparser
from util.logging_utils import logger
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
def type_convert(input_str: str):
@@ -11,10 +18,12 @@ def type_convert(input_str: str):
return input_str
PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
PROJECT_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
config_dir = "config"
CONFIG_DIR_PATH = os.path.join(PROJECT_DIR_PATH, config_dir)
config_file = "run_config.ini"
config_path = os.path.join(PROJECT_DIR_PATH, config_file)
config_path = os.path.join(CONFIG_DIR_PATH, config_file)
config = configparser.ConfigParser()
config.read(config_path)
@@ -26,9 +35,13 @@ chroma_db_section_name = "ChromaDB"
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
TEXT2DSL_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSL_COLLECTION_NAME')
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM'))
TEXT2DSLAGENT_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENT_COLLECTION_NAME')
TEXT2DSLAGENTCS_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENTCS_COLLECTION_NAME')
TEXT2DSL_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_EXAMPLE_NUM'))
TEXT2DSL_FEWSHOTS_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEWSHOTS_NUM'))
TEXT2DSL_SELF_CONSISTENCY_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_SELF_CONSISTENCY_NUM'))
TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT'))
TEXT2DSL_IS_SELF_CONSISTENCY = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SELF_CONSISTENCY'))
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
text2vec_section_name = "Text2Vec"
@@ -44,10 +57,14 @@ for option in config.options(llm_model_section_name):
if __name__ == "__main__":
logger.info("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH)
logger.info("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME)
logger.info("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH)
logger.info("LLMPARSER_HOST: ", LLMPARSER_HOST)
logger.info("LLMPARSER_PORT: ", LLMPARSER_PORT)
logger.info("llm_config_dict: ", llm_config_dict)
logger.info("is_shortcut: ", TEXT2DSL_IS_SHORTCUT)
logger.info(f"PROJECT_DIR_PATH: {PROJECT_DIR_PATH}")
logger.info(f"EMB_MODEL_PATH: {HF_TEXT2VEC_MODEL_NAME}")
logger.info(f"CHROMA_DB_PERSIST_PATH: {CHROMA_DB_PERSIST_PATH}")
logger.info(f"LLMPARSER_HOST: {LLMPARSER_HOST}")
logger.info(f"LLMPARSER_PORT: {LLMPARSER_PORT}")
logger.info(f"llm_config_dict: {llm_config_dict}")
logger.info(f"TEXT2DSL_EXAMPLE_NUM: {TEXT2DSL_EXAMPLE_NUM}")
logger.info(f"TEXT2DSL_FEWSHOTS_NUM: {TEXT2DSL_FEWSHOTS_NUM}")
logger.info(f"TEXT2DSL_SELF_CONSISTENCY_NUM: {TEXT2DSL_SELF_CONSISTENCY_NUM}")
logger.info(f"TEXT2DSL_IS_SHORTCUT: {TEXT2DSL_IS_SHORTCUT}")
logger.info(f"TEXT2DSL_IS_SELF_CONSISTENCY: {TEXT2DSL_IS_SELF_CONSISTENCY}")

View File

@@ -6,9 +6,13 @@ LLMPARSER_PORT = 9092
CHROMA_DB_PERSIST_DIR = chm_db
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
TEXT2DSL_COLLECTION_NAME = text2dsl_collection
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSLAGENT_COLLECTION_NAME = text2dsl_agent_collection
TEXT2DSLAGENTCS_COLLECTION_NAME = text2dsl_agent_cs_collection
TEXT2DSL_EXAMPLE_NUM = 15
TEXT2DSL_FEWSHOTS_NUM = 10
TEXT2DSL_SELF_CONSISTENCY_NUM = 5
TEXT2DSL_IS_SHORTCUT = False
TEXT2DSL_IS_SELF_CONSISTENCY = False
[Text2Vec]
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese

View File

@@ -228,10 +228,11 @@ examplars= [
"question":"邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的",
"prior_schema_links":"""['2312311'->MPPM歌手ID]""",
"analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中有哪些播放量大于500W的“中我们被问
“歌曲中有哪些”所以我们需要column=[歌曲名]
“播放量大于500W的”所以我们需要column=[结算播放量], cell values = [5000000],所以有[结算播放量:(5000000)]
”邓梓琦在2023年1月5日之后发布的“所以我们需要column=[发布时间], cell values = ['2023-01-05'],所以有[发布时间:('2023-01-05')]
”邓梓琦“所以我们需要column=[歌手名], cell values = ['邓梓琦'],所以有[歌手名:('邓梓琦')]""",
"schema_links":"""["结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""",
"schema_links":"""["歌曲名", "结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""",
"sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '2023-01-05' and 歌手名 = '邓梓琦' and 结算播放量 > 5000000"""
},
{ "current_date":"2023-09-17",

View File

@@ -0,0 +1,21 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union
import chromadb
from chromadb.api import Collection
from chromadb.config import Settings
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from config.config_parse import CHROMA_DB_PERSIST_PATH
client = chromadb.Client(
Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=CHROMA_DB_PERSIST_PATH, # Optional, defaults to .chromadb/ in the current directory
)
)

View File

@@ -0,0 +1,6 @@
from loguru import logger
import sys
logger.remove()  # remove the default handler; otherwise it would keep logging alongside the new handler added below
logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", level="INFO")

View File

@@ -5,7 +5,7 @@ import re
import sys
from typing import Any, List, Mapping, Union
from util.logging_utils import logger
from instances.logging_instance import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

View File

@@ -12,7 +12,7 @@ from plugin_call.prompt_construct import (
construct_task_prompt,
plugin_selection_output_parse,
)
from util.llm_instance import llm
from instances.llm_instance import llm
def plugin_selection_run(

View File

@@ -1,111 +0,0 @@
# -*- coding:utf-8 -*-
import os
import sys
import uuid
from typing import Any, List, Mapping
from chromadb.api import Collection
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
def get_ids(documents: List[str]) -> List[str]:
ids = []
for doc in documents:
ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, doc)))
return ids
def add2preset_query_collection(
collection: Collection, preset_queries: List[str], preset_query_ids: List[str]
) -> None:
collection.add(documents=preset_queries, ids=preset_query_ids)
def update_preset_query_collection(
collection: Collection, preset_queries: List[str], preset_query_ids: List[str]
) -> None:
collection.update(documents=preset_queries, ids=preset_query_ids)
def query2preset_query_collection(
collection: Collection, query_texts: List[str], n_results: int = 10
):
collection_cnt = collection.count()
min_n_results = 10
min_n_results = min(collection_cnt, min_n_results)
if n_results > min_n_results:
res = collection.query(query_texts=query_texts, n_results=n_results)
return res
else:
res = collection.query(query_texts=query_texts, n_results=min_n_results)
for _key in res.keys():
if res[_key] is None:
continue
for _idx in range(0, len(query_texts)):
res[_key][_idx] = res[_key][_idx][:n_results]
return res
def parse_retrieval_preset_query(res: List[Mapping[str, Any]]):
parsed_res = [[] for _ in range(0, len(res["ids"]))]
retrieval_ids = res["ids"]
retrieval_distances = res["distances"]
retrieval_sentences = res["documents"]
for query_idx in range(0, len(retrieval_ids)):
id_ls = retrieval_ids[query_idx]
distance_ls = retrieval_distances[query_idx]
sentence_ls = retrieval_sentences[query_idx]
for idx in range(0, len(id_ls)):
id = id_ls[idx]
distance = distance_ls[idx]
sentence = sentence_ls[idx]
parsed_res[query_idx].append(
{"id": id, "distance": distance, "presetQuery": sentence}
)
return parsed_res
def preset_query_retrieval_format(
query_list: List[str], retrieval_list: List[Mapping[str, Any]]
):
res = []
for query_idx in range(0, len(query_list)):
query = query_list[query_idx]
retrieval = retrieval_list[query_idx]
res.append({"query": query, "retrieval": retrieval})
return res
def empty_preset_query_collection(collection: Collection) -> None:
collection.delete()
def delete_preset_query_by_ids(
collection: Collection, preset_query_ids: List[str]
) -> None:
collection.delete(ids=preset_query_ids)
def get_preset_query_by_ids(collection: Collection, preset_query_ids: List[str]):
res = collection.get(ids=preset_query_ids)
return res
def preset_query_collection_size(collection: Collection) -> int:
return collection.count()

View File

@@ -1,51 +0,0 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import List
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from chromadb.api import Collection
from preset_query_db import (
query2preset_query_collection,
parse_retrieval_preset_query,
preset_query_retrieval_format,
preset_query_collection_size,
)
from util.text2vec import Text2VecEmbeddingFunction
from config.config_parse import PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import client
emb_func = Text2VecEmbeddingFunction()
collection = client.get_or_create_collection(
name=PRESET_QUERY_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"},
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
logger.info("init_preset_query_collection_size: {}", preset_query_collection_size(collection))
def preset_query_retrieval_run(
collection: Collection, query_texts_list: List[str], n_results: int = 5
):
retrieval_res = query2preset_query_collection(
collection=collection, query_texts=query_texts_list, n_results=n_results
)
parsed_retrieval_res = parse_retrieval_preset_query(retrieval_res)
parsed_retrieval_res_format = preset_query_retrieval_format(
query_texts_list, parsed_retrieval_res
)
logger.info("parsed_retrieval_res_format: {}", parsed_retrieval_res_format)
return parsed_retrieval_res_format

View File

@@ -0,0 +1,98 @@
# -*- coding:utf-8 -*-
import os
import sys
import uuid
from typing import Any, List, Mapping, Optional, Union
import chromadb
from chromadb import Client
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from chromadb.api.types import CollectionMetadata
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection,
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
get_chroma_collection_by_ids, get_chroma_collection_size,
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
empty_chroma_collection_2)
from instances.text2vec import Text2VecEmbeddingFunction
class ChromaCollectionRetriever(object):
def __init__(self, collection:Collection):
self.collection = collection
def retrieval_query_run(self, query_texts_list:List[str]=None, query_embeddings:Embeddings=None,
filter_condition:Mapping[str,str]=None, n_results:int=5):
retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings,
filter_condition, n_results)
parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res)
parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, query_embeddings, parsed_retrieval_res)
logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)
return parsed_retrieval_res_format
def get_query_by_ids(self, query_ids:List[str]):
queries = get_chroma_collection_by_ids(self.collection, query_ids)
return queries
def get_query_size(self):
return get_chroma_collection_size(self.collection)
def add_queries(self, query_text_list:List[str],
query_id_list:List[str],
metadatas:List[Mapping[str, str]]=None,
embeddings:Embeddings=None):
add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
return True
def update_queries(self, query_text_list:List[str],
query_id_list:List[str],
metadatas:List[Mapping[str, str]]=None,
embeddings:Embeddings=None):
update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
return True
def delete_queries_by_ids(self, query_ids:List[str]):
delete_chroma_collection_by_ids(self.collection, query_ids)
return True
def empty_query_collection(self):
self.collection = empty_chroma_collection_2(self.collection)
return True
class CollectionManager(object):
def __init__(self, chroma_client:Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None):
self.chroma_client = chroma_client
self.embedding_func = embedding_func
self.collection_meta = collection_meta
def list_collections(self):
collection_list = self.chroma_client.list_collections()
return collection_list
def get_collection(self, collection_name:str):
collection = self.chroma_client.get_collection(name=collection_name, embedding_function=self.embedding_func)
return collection
def create_collection(self, collection_name:str):
collection = self.chroma_client.create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
return collection
def get_or_create_collection(self, collection_name:str):
collection = self.chroma_client.get_or_create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
return collection
def delete_collection(self, collection_name:str):
self.chroma_client.delete_collection(collection_name)
return True

View File

@@ -8,80 +8,30 @@ from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from instances.logging_instance import logger
import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from util.text2vec import Text2VecEmbeddingFunction
from instances.text2vec import Text2VecEmbeddingFunction
from instances.chromadb_instance import client
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
from util.chromadb_instance import (client,
get_chroma_collection_size, query_chroma_collection,
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
get_chroma_collection_by_ids, get_chroma_collection_size,
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
empty_chroma_collection_2)
from retriever import ChromaCollectionRetriever, CollectionManager
emb_func = Text2VecEmbeddingFunction()
solved_query_collection = client.get_or_create_collection(name=SOLVED_QUERY_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"}
) # Get a collection object from an existing collection, by name. If it doesn't exist, create it.
logger.info("init_solved_query_collection_size: {}", get_chroma_collection_size(solved_query_collection))
collection_manager = CollectionManager(chroma_client=client, embedding_func=emb_func
,collection_meta={"hnsw:space": "cosine"})
solved_query_collection = collection_manager.get_or_create_collection(collection_name=SOLVED_QUERY_COLLECTION_NAME)
preset_query_collection = collection_manager.get_or_create_collection(collection_name=PRESET_QUERY_COLLECTION_NAME)
preset_query_collection = client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"}
)
logger.info("init_preset_query_collection_size: {}", get_chroma_collection_size(preset_query_collection))
class ChromaCollectionRetriever(object):
def __init__(self, collection:Collection):
self.collection = collection
def retrieval_query_run(self, query_texts_list:List[str],
filter_condition:Mapping[str,str]=None, n_results:int=5):
retrieval_res = query_chroma_collection(self.collection, query_texts_list,
filter_condition, n_results)
parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, parsed_retrieval_res)
logger.info('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)
return parsed_retrieval_res_format
def get_query_by_ids(self, query_ids:List[str]):
queries = get_chroma_collection_by_ids(self.collection, query_ids)
return queries
def get_query_size(self):
return get_chroma_collection_size(self.collection)
def add_queries(self, query_text_list:List[str],
query_id_list:List[str], metadatas:List[Mapping[str, str]]=None):
add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas)
return True
def update_queries(self, query_text_list:List[str],
query_id_list:List[str], metadatas:List[Mapping[str, str]]=None):
update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas)
return True
def delete_queries_by_ids(self, query_ids:List[str]):
delete_chroma_collection_by_ids(self.collection, query_ids)
return True
def empty_query_collection(self):
self.collection = empty_chroma_collection_2(self.collection)
return True
solved_query_retriever = ChromaCollectionRetriever(solved_query_collection)
preset_query_retriever = ChromaCollectionRetriever(preset_query_collection)
logger.info("init_solved_query_collection_size: {}".format(solved_query_retriever.get_query_size()))
logger.info("init_preset_query_collection_size: {}".format(preset_query_retriever.get_query_size()))

View File

@@ -2,87 +2,64 @@
import os
import sys
from typing import List, Mapping
from chromadb.api import Collection
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from instances.logging_instance import logger
from services.query_retrieval.retriever import ChromaCollectionRetriever
from langchain.vectorstores import Chroma
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
class FewShotPromptTemplate2(object):
def __init__(self, collection:Collection, few_shot_examples:List[Mapping[str, str]],
retrieval_key:str, few_shot_seperator:str = "\n\n") -> None:
self.collection = collection
self.few_shot_retriever = ChromaCollectionRetriever(self.collection)
from few_shot_example.sql_exampler import examplars as sql_examplars
from util.text2vec import hg_embedding
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from config.config_parse import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
self.few_shot_examples = few_shot_examples
self.retrieval_key = retrieval_key
self.few_shot_seperator = few_shot_seperator
def reload_sql_example_collection(
vectorstore: Chroma,
sql_examplars: List[Mapping[str, str]],
sql_example_selector: SemanticSimilarityExampleSelector,
example_nums: int,
):
logger.info("original sql_examples_collection size: {}", vectorstore._collection.count())
new_collection = empty_chroma_collection_2(collection=vectorstore._collection)
vectorstore._collection = new_collection
def add_few_shot_example(self, examples: List[Mapping[str, str]])-> None:
query_text_list = []
query_id_list = []
for idx, example in enumerate(examples):
query_text_list.append(example[self.retrieval_key])
query_id_list.append(str(idx))
logger.info("emptied sql_examples_collection size: {}", vectorstore._collection.count())
self.few_shot_retriever.add_queries(query_text_list=query_text_list, query_id_list=query_id_list, metadatas=examples)
sql_example_selector = SemanticSimilarityExampleSelector(
vectorstore=sql_examples_vectorstore,
k=example_nums,
input_keys=["question"],
example_keys=[
"table_name",
"fields_list",
"prior_schema_links",
"question",
"analysis",
"schema_links",
"current_date",
"sql",
],
)
def reload_few_shot_example(self, examples: List[Mapping[str, str]])-> None:
logger.info(f"original sql_examples_collection size: {self.few_shot_retriever.get_query_size()}")
for example in sql_examplars:
sql_example_selector.add_example(example)
self.few_shot_retriever.empty_query_collection()
logger.info(f"emptied sql_examples_collection size: {self.few_shot_retriever.get_query_size()}")
logger.info("reloaded sql_examples_collection size: {}", vectorstore._collection.count())
self.add_few_shot_example(examples=examples)
logger.info(f"reloaded sql_examples_collection size: {self.few_shot_retriever.get_query_size()}")
return vectorstore, sql_example_selector
def _sub_dict(self, d:Mapping[str, str], keys:List[str])-> Mapping[str, str]:
return {k:d[k] for k in keys if k in d}
def retrieve_few_shot_example(self, query_text: str, retrieval_num: int)-> List[Mapping[str, str]]:
query_text_list = [query_text]
retrieval_res_list = self.few_shot_retriever.retrieval_query_run(query_texts_list=query_text_list,
filter_condition=None, n_results=retrieval_num)
retrieval_res_unit_list = retrieval_res_list[0]['retrieval']
sql_examples_vectorstore = Chroma(
collection_name=TEXT2DSL_COLLECTION_NAME,
embedding_function=hg_embedding,
client=chromadb_client,
)
return retrieval_res_unit_list
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
def make_few_shot_example_prompt(self, few_shot_template: str, example_keys: List[str],
few_shot_example_meta_list: List[Mapping[str, str]])-> str:
few_shot_example_str_unit_list = []
sql_example_selector = SemanticSimilarityExampleSelector(
vectorstore=sql_examples_vectorstore,
k=example_nums,
input_keys=["question"],
example_keys=[
"table_name",
"fields_list",
"prior_schema_links",
"question",
"analysis",
"schema_links",
"current_date",
"sql",
],
)
retrieval_metas_list = [self._sub_dict(few_shot_example_meta['metadata'], example_keys) for few_shot_example_meta in few_shot_example_meta_list]
if sql_examples_vectorstore._collection.count() > 0:
logger.info("examples already in sql_vectorstore")
logger.info("init sql_vectorstore size: {}", sql_examples_vectorstore._collection.count())
for meta in retrieval_metas_list:
few_shot_example_str_unit_list.append(few_shot_template.format(**meta))
few_shot_example_str = self.few_shot_seperator.join(few_shot_example_str_unit_list)
return few_shot_example_str
logger.info("sql_examplars size: {}", len(sql_examplars))
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(
sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums
)
logger.info("added sql_vectorstore size: {}", sql_examples_vectorstore._collection.count())

View File

@@ -6,41 +6,54 @@ from typing import List, Mapping
import requests
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from config.config_parse import (TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT,
LLMPARSER_HOST, LLMPARSER_PORT)
from instances.logging_instance import logger
from config.config_parse import (TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
LLMPARSER_HOST, LLMPARSER_PORT, TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
from few_shot_example.sql_exampler import examplars as sql_examplars
from util.logging_utils import logger
def text2dsl_setting_update(
llm_parser_host: str,
llm_parser_port: str,
sql_examplars: List[Mapping[str, str]],
example_nums: int,
is_shortcut: bool,
):
def text2sql_agent_setting_update(llm_host:str, llm_port:str,
sql_examplars:List[Mapping[str, str]], example_nums:int):
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
logger.info("url: {}", url)
payload = {
"sqlExamplars": sql_examplars,
"exampleNums": example_nums,
"isShortcut": is_shortcut,
}
headers = {"content-type": "application/json"}
url = f"http://{llm_host}:{llm_port}/text2sql_agent_setting_update/"
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums}
headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
logger.info(response.text)
def text2dsl_agent_cs_setting_update(llm_host:str, llm_port:str,
sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int):
url = f"http://{llm_host}:{llm_port}/texg2sqt_cs_agent_setting_update/"
payload = {"sqlExamplars":sql_examplars,
"exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums}
headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
logger.info(response.text)
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
is_shortcut:bool, is_self_consistency:bool,
sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int):
url = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
payload = {"isShortcut":is_shortcut, "isSelfConsistency":is_self_consistency,
"sqlExamplars":sql_examplars,
"exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums}
headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
logger.info(response.text)
if __name__ == "__main__":
text2dsl_setting_update(
LLMPARSER_HOST,
LLMPARSER_PORT,
sql_examplars,
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM,
TEXT2DSL_IS_SHORTCUT,
)
text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT,
TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY,
sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
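For reference, the request body that text2dsl_agent_wrapper_setting_update posts to /query2sql_setting_update looks like the sketch below; the numbers shown are the defaults from run_config.ini:
# Illustrative payload; values mirror the run_config.ini defaults above.
payload = {
    "isShortcut": False,             # TEXT2DSL_IS_SHORTCUT
    "isSelfConsistency": False,      # TEXT2DSL_IS_SELF_CONSISTENCY
    "sqlExamplars": sql_examplars,   # few-shot examples from few_shot_example.sql_exampler
    "exampleNums": 15,               # TEXT2DSL_EXAMPLE_NUM
    "fewshotNums": 10,               # TEXT2DSL_FEWSHOTS_NUM
    "selfConsistencyNums": 5,        # TEXT2DSL_SELF_CONSISTENCY_NUM
}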

View File

@@ -8,16 +8,14 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from instances.logging_instance import logger
def schema_link_parse(schema_link_output):
try:
schema_link_output = schema_link_output.strip()
pattern = r"Schema_links:(.*)"
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[
0
].strip()
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
except Exception as e:
logger.exception(e)
schema_link_output = None

View File

@@ -1,166 +0,0 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import List, Mapping
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from langchain.prompts import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
def schema_linking_exampler(
user_query: str,
domain_name: str,
fields_list: List[str],
prior_schema_links: Mapping[str, str],
example_selector: SemanticSimilarityExampleSelector,
) -> str:
prior_schema_links_str = (
"["
+ ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()])
+ "]"
)
example_prompt_template = PromptTemplate(
input_variables=[
"table_name",
"fields_list",
"prior_schema_links",
"question",
"analysis",
"schema_links",
],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}",
)
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
schema_linking_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
schema_linking_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "prior_schema_links", "question"],
suffix=schema_linking_prompt,
)
schema_linking_example_prompt = schema_linking_example_prompt_template.format(
table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
question=user_query,
)
return schema_linking_example_prompt
def sql_exampler(
user_query: str,
domain_name: str,
schema_link_str: str,
data_date: str,
example_selector: SemanticSimilarityExampleSelector,
) -> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
sql_example_prompt_template = PromptTemplate(
input_variables=[
"question",
"current_date",
"table_name",
"schema_links",
"sql",
],
template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}",
)
sql_prompt = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
sql_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=sql_example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["question", "current_date", "table_name", "schema_links"],
suffix=sql_prompt,
)
sql_example_prompt = sql_example_prompt_template.format(
question=user_query,
current_date=data_date,
table_name=domain_name,
schema_links=schema_link_str,
)
return sql_example_prompt
def schema_linking_sql_combo_examplar(
user_query: str,
domain_name: str,
data_date: str,
fields_list: List[str],
prior_schema_links: Mapping[str, str],
example_selector: SemanticSimilarityExampleSelector,
) -> str:
prior_schema_links_str = (
"["
+ ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()])
+ "]"
)
example_prompt_template = PromptTemplate(
input_variables=[
"table_name",
"fields_list",
"prior_schema_links",
"current_date",
"question",
"analysis",
"schema_links",
"sql",
],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}",
)
instruction = (
"# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
)
schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=[
"table_name",
"fields_list",
"prior_schema_links",
"current_date",
"question",
],
suffix=schema_linking_sql_combo_prompt,
)
schema_linking_sql_combo_example_prompt = (
schema_linking_sql_combo_example_prompt_template.format(
table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
current_date=data_date,
question=user_query,
)
)
return schema_linking_sql_combo_example_prompt

View File

@@ -1,188 +1,56 @@
# -*- coding:utf-8 -*-
import asyncio
import os
import sys
from typing import List, Union, Mapping
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from util.logging_utils import logger
from sql.constructor import FewShotPromptTemplate2
from sql.sql_agent import Text2DSLAgent, Text2DSLAgentConsistency, Text2DSLAgentWrapper
from sql.prompt_maker import (
schema_linking_exampler,
sql_exampler,
schema_linking_sql_combo_examplar,
)
from sql.constructor import (
sql_examples_vectorstore,
sql_example_selector,
reload_sql_example_collection,
)
from sql.output_parser import (
schema_link_parse,
combo_schema_link_parse,
combo_sql_parse,
)
from instances.llm_instance import llm
from instances.text2vec import Text2VecEmbeddingFunction
from instances.chromadb_instance import client
from instances.logging_instance import logger
from util.llm_instance import llm
from config.config_parse import TEXT2DSL_IS_SHORTCUT
from few_shot_example.sql_exampler import examplars as sql_examplars
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTCS_COLLECTION_NAME,
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
class Text2DSLAgent(object):
def __init__(self):
self.schema_linking_exampler = schema_linking_exampler
self.sql_exampler = sql_exampler
emb_func = Text2VecEmbeddingFunction()
text2dsl_agent_collection = client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"})
text2dsl_agentcs_collection = client.get_or_create_collection(name=TEXT2DSLAGENTCS_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"})
self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar
text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection,
few_shot_examples=sql_examplars,
retrieval_key="question",
few_shot_seperator='\n\n')
self.sql_examples_vectorstore = sql_examples_vectorstore
self.sql_example_selector = sql_example_selector
text2dsl_agentcs_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agentcs_collection,
few_shot_examples=sql_examplars,
retrieval_key="question",
few_shot_seperator='\n\n')
self.schema_link_parse = schema_link_parse
self.combo_schema_link_parse = combo_schema_link_parse
self.combo_sql_parse = combo_sql_parse
text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_EXAMPLE_NUM,
sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)
self.llm = llm
text2sql_cs_agent = Text2DSLAgentConsistency(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
sql_example_prompter=text2dsl_agentcs_example_prompter, llm=llm)
self.is_shortcut = TEXT2DSL_IS_SHORTCUT
text2sql_agent.update_examples(sql_examplars, TEXT2DSL_EXAMPLE_NUM)
def update_examples(self, sql_examples, example_nums, is_shortcut):
(
self.sql_examples_vectorstore,
self.sql_example_selector,
) = reload_sql_example_collection(
self.sql_examples_vectorstore,
sql_examples,
self.sql_example_selector,
example_nums,
)
self.is_shortcut = is_shortcut
def query2sql(
self,
query_text: str,
schema: Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None,
):
logger.info("query_text: {}".format(query_text))
logger.info("schema: {}".format(schema))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(linking))
if linking is not None:
prior_schema_links = {
item["fieldValue"]: item["fieldName"] for item in linking
}
else:
prior_schema_links = {}
model_name = schema["modelName"]
fields_list = schema["fieldNameList"]
schema_linking_prompt = self.schema_linking_exampler(
query_text,
model_name,
fields_list,
prior_schema_links,
self.sql_example_selector,
)
logger.info("schema_linking_prompt-> {}".format(schema_linking_prompt))
schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output)
sql_prompt = self.sql_exampler(
query_text,
model_name,
schema_link_str,
current_date,
self.sql_example_selector,
)
logger.info("sql_prompt-> {}".format(sql_prompt))
sql_output = self.llm(sql_prompt)
resp = dict()
resp["query"] = query_text
resp["model"] = model_name
resp["fields"] = fields_list
resp["priorSchemaLinking"] = linking
resp["dataDate"] = current_date
resp["analysisOutput"] = schema_link_output
resp["schemaLinkStr"] = schema_link_str
resp["sqlOutput"] = sql_output
logger.info("resp: {}".format(resp))
return resp
def query2sqlcombo(
self,
query_text: str,
schema: Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None,
):
logger.info("query_text: {}".format(query_text))
logger.info("schema: {}".format(schema))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(linking))
if linking is not None:
prior_schema_links = {
item["fieldValue"]: item["fieldName"] for item in linking
}
else:
prior_schema_links = {}
model_name = schema["modelName"]
fields_list = schema["fieldNameList"]
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(
query_text,
model_name,
current_date,
fields_list,
prior_schema_links,
self.sql_example_selector,
)
logger.info("schema_linking_sql_combo_prompt-> {}".format(schema_linking_sql_combo_prompt))
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
schema_linking_str = self.combo_schema_link_parse(
schema_linking_sql_combo_output
)
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
resp = dict()
resp["query"] = query_text
resp["model"] = model_name
resp["fields"] = fields_list
resp["priorSchemaLinking"] = prior_schema_links
resp["dataDate"] = current_date
resp["analysisOutput"] = schema_linking_sql_combo_output
resp["schemaLinkStr"] = schema_linking_str
resp["sqlOutput"] = sql_str
logger.info("resp: {}".format(resp))
return resp
def query2sql_run(
self,
query_text: str,
schema: Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None,
):
if self.is_shortcut:
return self.query2sqlcombo(query_text, schema, current_date, linking)
else:
return self.query2sql(query_text, schema, current_date, linking)
text2sql_cs_agent.update_examples(sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
text2sql_agent = Text2DSLAgent()
text2sql_agent_router = Text2DSLAgentWrapper(sql_agent=text2sql_agent, sql_agent_cs=text2sql_cs_agent,
is_shortcut=TEXT2DSL_IS_SHORTCUT, is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY)

View File

@@ -0,0 +1,380 @@
import os
import sys
from typing import List, Union, Mapping, Any
from collections import Counter
import random
import asyncio
from langchain.llms.base import BaseLLM
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from sql.constructor import FewShotPromptTemplate2
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
class Text2DSLAgent(object):
def __init__(self, num_fewshots:int,
sql_example_prompter:FewShotPromptTemplate2,
llm: BaseLLM):
self.num_fewshots = num_fewshots
self.sql_example_prompter = sql_example_prompter
self.llm = llm
def update_examples(self, sql_examplars, num_fewshots):
self.num_fewshots = num_fewshots
self.sql_example_prompter.reload_few_shot_example(sql_examplars)
def get_fewshot_examples(self, query_text: str)->List[Mapping[str, str]]:
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_fewshots)
return few_shot_example_meta_list
def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
schema_linking_example_keys = ["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]
schema_linking_example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}"
schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
example_keys=schema_linking_example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, question=user_query)
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
return schema_linking_prompt
def generate_sql_prompt(self, user_query: str, domain_name: str,
schema_link_str: str, data_date: str,
fewshot_example_list:List[Mapping[str, str]])-> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
sql_example_keys = ["question", "current_date", "table_name", "schema_links", "sql"]
sql_example_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}"
sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
new_case_prompt = new_case_template.format(question=user_query, current_date=data_date, table_name=domain_name, schema_links=schema_link_str)
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
return sql_example_prompt
def generate_schema_linking_sql_prompt(self, user_query: str,
domain_name: str,
data_date : str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
fewshot_example_list:List[Mapping[str, str]]):
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
example_keys = ["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"]
example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}"
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
example_keys=example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, current_date=data_date, question=user_query)
prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt
return prompt
async def async_query2sql(self, query_text: str,
model_name: str, fields_list: List[str],
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("query_text: {}".format(query_text))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("data_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
query_text = query_text + ' 备注:'+prior_exts
logger.info("query_text_prior_exts: {}".format(query_text))
fewshot_example_meta_list = self.get_fewshot_examples(query_text)
schema_linking_prompt = self.generate_schema_linking_prompt(query_text, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
schema_link_output = await self.llm._call_async(schema_linking_prompt)
schema_link_str = schema_link_parse(schema_link_output)
sql_prompt = self.generate_sql_prompt(query_text, model_name, schema_link_str, data_date, fewshot_example_meta_list)
logger.debug("sql_prompt->{}".format(sql_prompt))
sql_output = await self.llm._call_async(sql_prompt)
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
logger.info("resp: {}".format(resp))
return resp
async def async_query2sql_shortcut(self, query_text: str,
model_name: str, fields_list: List[str],
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("query_text: {}".format(query_text))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("data_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
query_text = query_text + ' 备注:'+prior_exts
logger.info("query_text_prior_exts: {}".format(query_text))
fewshot_example_meta_list = self.get_fewshot_examples(query_text)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(query_text, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
logger.info("resp: {}".format(resp))
return resp
class Text2DSLAgentConsistency(object):
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM) -> None:
self.num_fewshots = num_fewshots
self.num_examples = num_examples
assert self.num_fewshots <= self.num_examples
self.num_self_consistency = num_self_consistency
self.llm = llm
self.sql_example_prompter = sql_example_prompter
def update_examples(self, sql_examplars, num_examples, num_fewshots, num_self_consistency):
self.num_fewshots = num_fewshots
self.num_examples = num_examples
assert self.num_fewshots <= self.num_examples
self.num_self_consistency = num_self_consistency
assert self.num_self_consistency >= 1
self.sql_example_prompter.reload_few_shot_example(sql_examplars)
def get_examples_candidates(self, query_text: str)->List[Mapping[str, str]]:
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_examples)
return few_shot_example_meta_list
def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]])-> List[List[Mapping[str, str]]]:
fewshot_example_list = []
for i in range(0, self.num_self_consistency):
random.shuffle(example_meta_list)
fewshot_example_list.append(example_meta_list[:self.num_fewshots])
return fewshot_example_list
def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
schema_linking_example_keys = ["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]
schema_linking_example_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}"
schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
example_keys=schema_linking_example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析: 让我们一步一步地思考。"
new_case_prompt = new_case_template.format(table_name=domain_name, fields_list=fields_list, prior_schema_links=prior_schema_links_str, question=user_query)
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
return schema_linking_prompt
def generate_schema_linking_prompt_pool(self, user_query: str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
schema_linking_prompt_pool = []
for fewshot_example_list in fewshot_example_list_pool:
schema_linking_prompt = self.generate_schema_linking_prompt(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list)
schema_linking_prompt_pool.append(schema_linking_prompt)
return schema_linking_prompt_pool
def generate_sql_prompt(self, user_query: str, domain_name: str,
schema_link_str: str, data_date: str,
fewshot_example_list:List[Mapping[str, str]])-> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
sql_example_keys = ["question", "current_date", "table_name", "schema_links", "sql"]
sql_example_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}"
sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:"
new_case_prompt = new_case_template.format(question=user_query, current_date=data_date, table_name=domain_name, schema_links=schema_link_str)
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
return sql_example_prompt
def generate_sql_prompt_pool(self, user_query: str, domain_name: str, data_date: str,
schema_link_str_pool: List[str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
sql_prompt_pool = []
for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
sql_prompt = self.generate_sql_prompt(user_query, domain_name, schema_link_str, data_date, fewshot_example_list)
sql_prompt_pool.append(sql_prompt)
return sql_prompt_pool
def self_consistency_vote(self, output_res_pool:List[str]):
output_res_counts = Counter(output_res_pool)
output_res_max = output_res_counts.most_common(1)[0][0]
total_output_num = len(output_res_pool)
vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()}
return output_res_max, vote_percentage
def schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]:
schema_linking_list_unify = []
for schema_linking_str in schema_linking_list:
schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')]))
schema_linking_str_unify = f'[{schema_linking_str_unify}]'
schema_linking_list_unify.append(schema_linking_str_unify)
return schema_linking_list_unify
async def generate_schema_linking_tasks(self, user_query: str, domain_name: str,
fields_list: List[str], prior_schema_links: Mapping[str,str],
fewshot_example_list_combo:List[List[Mapping[str, str]]]):
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(user_query, domain_name,
fields_list, prior_schema_links,
fewshot_example_list_combo)
schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
schema_linking_output_res_pool = await asyncio.gather(*schema_linking_output_task_pool)
logger.debug(f'schema_linking_output_res_pool:{schema_linking_output_res_pool}')
return schema_linking_output_res_pool
async def generate_sql_tasks(self, user_query: str, domain_name: str, data_date: str,
schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
sql_prompt_pool = self.generate_sql_prompt_pool(user_query, domain_name, data_date, schema_link_str_pool, fewshot_example_list_combo)
sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')
return sql_output_res_pool
async def tasks_run(self, user_query: str, domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
logger.info("user_query: {}".format(user_query))
logger.info("domain_name: {}".format(domain_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
user_query = user_query + ' 备注:'+prior_exts
logger.info("user_query_prior_exts: {}".format(user_query))
fewshot_example_meta_list = self.get_examples_candidates(user_query)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list)
schema_linking_output_candidates = await self.generate_schema_linking_tasks(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
schema_linking_candidate_list = [schema_link_parse(schema_linking_output_candidate) for schema_linking_output_candidate in schema_linking_output_candidates]
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates = await self.generate_sql_tasks(user_query, domain_name, data_date, schema_linking_candidate_list,fewshot_example_list_combo)
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
resp = dict()
resp['query'] = user_query
resp['model'] = domain_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
resp['sqlOutput'] = sql_output_max
resp['sqlWeight'] = sql_output_vote_percentage
logger.info("resp: {}".format(resp))
return resp
class Text2DSLAgentWrapper(object):
    """Routes a text2sql request to the plain agent or the self-consistency agent."""

    def __init__(self, sql_agent: Text2DSLAgent, sql_agent_cs: Text2DSLAgentConsistency,
                 is_shortcut: bool, is_self_consistency: bool):
        self.sql_agent = sql_agent
        self.sql_agent_cs = sql_agent_cs
        self.is_shortcut = is_shortcut
        self.is_self_consistency = is_self_consistency

    async def async_query2sql(self, query_text: str,
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        # self-consistency takes priority over the shortcut path
        if self.is_self_consistency:
            logger.info("sql wrapper: self_consistency")
            resp = await self.sql_agent_cs.tasks_run(user_query=query_text, domain_name=model_name, fields_list=fields_list, prior_schema_links=prior_schema_links, data_date=data_date, prior_exts=prior_exts)
            return resp
        elif self.is_shortcut:
            logger.info("sql wrapper: shortcut")
            resp = await self.sql_agent.async_query2sql_shortcut(query_text=query_text, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp
        else:
            logger.info("sql wrapper: normal")
            resp = await self.sql_agent.async_query2sql(query_text=query_text, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp

    def update_configs(self, is_shortcut, is_self_consistency,
                       sql_examplars, num_examples, num_fewshots, num_self_consistency):
        self.is_shortcut = is_shortcut
        self.is_self_consistency = is_self_consistency
        self.sql_agent.update_examples(sql_examplars=sql_examplars, num_fewshots=num_examples)
        self.sql_agent_cs.update_examples(sql_examplars=sql_examplars, num_examples=num_examples, num_fewshots=num_fewshots, num_self_consistency=num_self_consistency)
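
# Illustrative usage sketch (assumed values, not part of the original code): given the two
# agents already constructed elsewhere, the wrapper is driven like this, with the flags read
# from run_config.ini (TEXT2DSL_IS_SHORTCUT / TEXT2DSL_IS_SELF_CONSISTENCY):
#
#   text2sql_agent_router = Text2DSLAgentWrapper(sql_agent, sql_agent_cs,
#                                                is_shortcut=TEXT2DSL_IS_SHORTCUT,
#                                                is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY)
#   resp = await text2sql_agent_router.async_query2sql(query_text="total sales in the last 7 days",
#                                                      model_name="sales", fields_list=["category", "sales", "date"],
#                                                      data_date="2023-10-31", prior_schema_links={}, prior_exts="")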

View File

@@ -8,59 +8,81 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from fastapi import APIRouter, Depends, HTTPException
-from services.sql.run import text2sql_agent
+from services.sql.run import text2sql_agent_router
 router = APIRouter()
-@router.post("/query2sql/")
-def din_query2sql(query_body: Mapping[str, Any]):
-    if "queryText" not in query_body:
+@router.post("/query2sql")
+async def query2sql(query_body: Mapping[str, Any]):
+    if 'queryText' not in query_body:
         raise HTTPException(status_code=400, detail="query_text is not in query_body")
     else:
-        query_text = query_body["queryText"]
+        query_text = query_body['queryText']
-    if "schema" not in query_body:
+    if 'schema' not in query_body:
         raise HTTPException(status_code=400, detail="schema is not in query_body")
     else:
-        schema = query_body["schema"]
-    if "currentDate" not in query_body:
+        schema = query_body['schema']
+    if 'currentDate' not in query_body:
         raise HTTPException(status_code=400, detail="currentDate is not in query_body")
     else:
-        current_date = query_body["currentDate"]
+        current_date = query_body['currentDate']
-    if "linking" not in query_body:
-        linking = None
+    if 'linking' not in query_body:
+        raise HTTPException(status_code=400, detail="linking is not in query_body")
     else:
-        linking = query_body["linking"]
+        linking = query_body['linking']
-    resp = text2sql_agent.query2sql_run(
-        query_text=query_text, schema=schema, current_date=current_date, linking=linking
-    )
+    if 'priorExts' not in query_body:
+        raise HTTPException(status_code=400, detail="prior_exts is not in query_body")
+    else:
+        prior_exts = query_body['priorExts']
+    model_name = schema['modelName']
+    fields_list = schema['fieldNameList']
+    prior_schema_links = {item['fieldValue']: item['fieldName'] for item in linking}
+    resp = await text2sql_agent_router.async_query2sql(query_text=query_text, model_name=model_name, fields_list=fields_list,
+                                                       data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
     return resp
-@router.post("/query2sql_setting_update/")
+@router.post("/query2sql_setting_update")
 def query2sql_setting_update(query_body: Mapping[str, Any]):
-    if "sqlExamplars" not in query_body:
+    if 'sqlExamplars' not in query_body:
         raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
     else:
-        sql_examplars = query_body["sqlExamplars"]
+        sql_examplars = query_body['sqlExamplars']
-    if "exampleNums" not in query_body:
+    if 'exampleNums' not in query_body:
         raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
     else:
-        example_nums = query_body["exampleNums"]
+        example_nums = query_body['exampleNums']
-    if "isShortcut" not in query_body:
+    if 'fewshotNums' not in query_body:
+        raise HTTPException(status_code=400, detail="fewshotNums is not in query_body")
+    else:
+        fewshot_nums = query_body['fewshotNums']
+    if 'selfConsistencyNums' not in query_body:
+        raise HTTPException(status_code=400, detail="selfConsistencyNums is not in query_body")
+    else:
+        self_consistency_nums = query_body['selfConsistencyNums']
+    if 'isShortcut' not in query_body:
         raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
     else:
-        is_shortcut = query_body["isShortcut"]
+        is_shortcut = query_body['isShortcut']
-    text2sql_agent.update_examples(
-        sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut
-    )
+    if 'isSelfConsistency' not in query_body:
+        raise HTTPException(status_code=400, detail="isSelfConsistency is not in query_body")
+    else:
+        is_self_consistency = query_body['isSelfConsistency']
+    text2sql_agent_router.update_configs(is_shortcut=is_shortcut, is_self_consistency=is_self_consistency, sql_examplars=sql_examplars,
+                                         num_examples=example_nums, num_fewshots=fewshot_nums, num_self_consistency=self_consistency_nums)
     return "success"

View File

@@ -1,4 +0,0 @@
from loguru import logger
import sys
# logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}")

View File

@@ -4,22 +4,14 @@ from typing import Any, List, Mapping, Optional, Union
 import chromadb
-from chromadb.api import Collection
-from chromadb.config import Settings
+from chromadb.api import Collection, Documents, Embeddings
 import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
-from util.logging_utils import logger
-from config.config_parse import CHROMA_DB_PERSIST_PATH
-client = chromadb.Client(
-    Settings(
-        chroma_db_impl="duckdb+parquet",
-        persist_directory=CHROMA_DB_PERSIST_PATH,  # Optional, defaults to .chromadb/ in the current directory
-    )
-)
+from instances.logging_instance import logger
 def empty_chroma_collection_2(collection:Collection):
@@ -48,26 +40,30 @@ def empty_chroma_collection(collection:Collection) -> None:
 def add_chroma_collection(collection:Collection,
                           queries:List[str],
                           query_ids:List[str],
-                          metadatas:List[Mapping[str, str]]=None
+                          metadatas:List[Mapping[str, str]]=None,
+                          embeddings:Embeddings=None
                           ) -> None:
     collection.add(documents=queries,
                    ids=query_ids,
-                   metadatas=metadatas)
+                   metadatas=metadatas,
+                   embeddings=embeddings)
 def update_chroma_collection(collection:Collection,
                              queries:List[str],
                              query_ids:List[str],
-                             metadatas:List[Mapping[str, str]]=None
+                             metadatas:List[Mapping[str, str]]=None,
+                             embeddings:Embeddings=None
                              ) -> None:
     collection.update(documents=queries,
                       ids=query_ids,
-                      metadatas=metadatas)
+                      metadatas=metadatas,
+                      embeddings=embeddings)
-def query_chroma_collection(collection:Collection, query_texts:List[str],
+def query_chroma_collection(collection:Collection, query_texts:List[str]=None, query_embeddings:Embeddings=None,
                             filter_condition:Mapping[str,str]=None, n_results:int=10):
     outer_opt = '$and'
     inner_opt = '$eq'
@@ -81,8 +77,10 @@ def query_chroma_collection(collection:Collection, query_texts:List[str],
     else:
         outer_filter = None
-    print('outer_filter: ', outer_filter)
-    res = collection.query(query_texts=query_texts, n_results=n_results, where=outer_filter)
+    logger.info('outer_filter: {}'.format(outer_filter))
+    res = collection.query(query_texts=query_texts, query_embeddings=query_embeddings,
+                           n_results=n_results, where=outer_filter)
     return res
@@ -115,16 +113,32 @@ def parse_retrieval_chroma_collection_query(res:List[Mapping[str, Any]]):
     return parsed_res
-def chroma_collection_query_retrieval_format(query_list:List[str], retrieval_list:List[Mapping[str, Any]]):
+def chroma_collection_query_retrieval_format(query_list:List[str], query_embeddings:Embeddings, retrieval_list:List[Mapping[str, Any]]):
     res = []
-    for query_idx in range(0, len(query_list)):
-        query = query_list[query_idx]
-        retrieval = retrieval_list[query_idx]
-        res.append({
-            'query': query,
-            'retrieval': retrieval
-        })
+    if query_list is not None and query_embeddings is not None:
+        raise Exception("query_list and query_embeddings are not None")
+    if query_list is None and query_embeddings is None:
+        raise Exception("query_list and query_embeddings are None")
+    if query_list is not None:
+        for query_idx in range(0, len(query_list)):
+            query = query_list[query_idx]
+            retrieval = retrieval_list[query_idx]
+            res.append({
+                'query': query,
+                'retrieval': retrieval
+            })
+    else:
+        for query_idx in range(0, len(query_embeddings)):
+            query_embedding = query_embeddings[query_idx]
+            retrieval = retrieval_list[query_idx]
+            res.append({
+                'query_embedding': query_embedding,
+                'retrieval': retrieval
+            })
     return res
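
For illustration, a usage sketch of the embedding-based query path added above, assuming an existing Collection object and a flat field-to-value filter_condition that the helper expands into $and/$eq where-clauses (as the outer_opt/inner_opt variables suggest). The collection name, embedding values, and filter field are placeholders, and parse_retrieval_chroma_collection_query is assumed to return one retrieval entry per query.

# query by a precomputed embedding instead of raw text
res = query_chroma_collection(collection=preset_query_collection,      # assumed existing Collection
                              query_embeddings=[[0.12, 0.34, 0.56]],    # one embedding per query
                              filter_condition={'model_name': 'sales'},
                              n_results=5)
parsed = parse_retrieval_chroma_collection_query(res)
combos = chroma_collection_query_retrieval_format(query_list=None,
                                                  query_embeddings=[[0.12, 0.34, 0.56]],
                                                  retrieval_list=parsed)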