mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
[improvement][chat] Move python code out of chat-core module
This commit is contained in:
75
chat/python/services/sql/constructor.py
Normal file
75
chat/python/services/sql/constructor.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
from chromadb.api import Collection
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
from services.query_retrieval.retriever import ChromaCollectionRetriever
|
||||
|
||||
class FewShotPromptTemplate2(object):
    """Manage few-shot examples in a Chroma collection and render them into prompts.

    Every example unit is a mapping. The value stored under ``retrieval_key``
    is the text handed to the retriever, and the unit itself is passed along
    as the metadata of that entry.
    """

    def __init__(self, collection: "Collection", retrieval_key: str,
                 few_shot_seperator: str = "\n\n") -> None:
        """Wrap ``collection`` in a ``ChromaCollectionRetriever``.

        Args:
            collection: Collection that holds the examples.
            retrieval_key: Key of each example unit whose value is used as
                the query text.
            few_shot_seperator: String placed between rendered examples.
                The (misspelled) name is kept because it is a keyword
                parameter of the public constructor.
        """
        self.collection = collection
        self.few_shot_retriever = ChromaCollectionRetriever(self.collection)

        self.retrieval_key = retrieval_key

        self.few_shot_seperator = few_shot_seperator

    def _example_query_texts(self, example_units: List[Mapping[str, str]]) -> List[str]:
        """Return the ``retrieval_key`` value of every example unit, in order.

        Raises:
            KeyError: If a unit does not contain ``retrieval_key``.
        """
        return [example_unit[self.retrieval_key] for example_unit in example_units]

    def add_few_shot_example(self, example_ids: List[str],
                             example_units: List[Mapping[str, str]]) -> None:
        """Add examples to the retriever; ``example_ids`` pairs with ``example_units``."""
        self.few_shot_retriever.add_queries(
            query_text_list=self._example_query_texts(example_units),
            query_id_list=example_ids,
            metadatas=example_units,
        )

    def update_few_shot_example(self, example_ids: List[str],
                                example_units: List[Mapping[str, str]]) -> None:
        """Update the examples identified by ``example_ids`` with ``example_units``."""
        self.few_shot_retriever.update_queries(
            query_text_list=self._example_query_texts(example_units),
            query_id_list=example_ids,
            metadatas=example_units,
        )

    def delete_few_shot_example(self, example_ids: List[str]) -> None:
        """Delete the examples whose ids are listed in ``example_ids``."""
        self.few_shot_retriever.delete_queries_by_ids(query_ids=example_ids)

    def count_few_shot_example(self) -> int:
        """Return the size reported by the retriever."""
        return self.few_shot_retriever.get_query_size()

    def reload_few_shot_example(self, example_ids: List[str],
                                example_units: List[Mapping[str, str]]) -> None:
        """Empty the collection and load ``example_units`` again.

        The collection size is logged before emptying, after emptying and
        after the reload.
        """
        logger.info(f"original {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")

        self.few_shot_retriever.empty_query_collection()
        logger.info(f"emptied {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")

        self.add_few_shot_example(example_ids=example_ids, example_units=example_units)
        logger.info(f"reloaded {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")

    def _sub_dict(self, d: Mapping[str, str], keys: List[str]) -> Mapping[str, str]:
        """Return the entries of ``d`` whose key is in ``keys``; absent keys are skipped."""
        return {k: d[k] for k in keys if k in d}

    def retrieve_few_shot_example(self, query_text: str, retrieval_num: int,
                                  filter_condition: Mapping[str, str] = None) -> List[Mapping[str, str]]:
        """Retrieve up to ``retrieval_num`` examples for ``query_text``.

        Only one query is sent, so the ``'retrieval'`` entry of the first
        (and only) result is returned.
        """
        retrieval_res_list = self.few_shot_retriever.retrieval_query_run(
            query_texts_list=[query_text],
            filter_condition=filter_condition,
            n_results=retrieval_num,
        )

        return retrieval_res_list[0]['retrieval']

    def make_few_shot_example_prompt(self, few_shot_template: str, example_keys: List[str],
                                     few_shot_example_meta_list: List[Mapping[str, str]]) -> str:
        """Render each retrieved example with ``few_shot_template`` and join them.

        Args:
            few_shot_template: ``str.format`` template using names from
                ``example_keys`` as placeholders.
            example_keys: Metadata keys made available to the template.
            few_shot_example_meta_list: Items carrying a ``'metadata'``
                mapping (the same key this class stores example units under).

        Returns:
            The rendered examples joined by ``few_shot_seperator``; an empty
            string when the list is empty.

        Raises:
            KeyError: If the template references a key that is missing from
                an example's metadata.
        """
        rendered_examples = [
            few_shot_template.format(**self._sub_dict(example_meta['metadata'], example_keys))
            for example_meta in few_shot_example_meta_list
        ]

        return self.few_shot_seperator.join(rendered_examples)
|
||||
61
chat/python/services/sql/examples_reload_run.py
Normal file
61
chat/python/services/sql/examples_reload_run.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from config.config_parse import (TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT, TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
from few_shot_example.sql_examplar import examplars as sql_examplars
|
||||
|
||||
|
||||
def text2sql_agent_setting_update(llm_host: str, llm_port: str,
                                  sql_examplars: List[Mapping[str, str]], example_nums: int):
    """POST the SQL exemplars to the ``text2sql_agent_setting_update`` endpoint.

    Args:
        llm_host: Host of the service to call.
        llm_port: Port of the service to call.
        sql_examplars: Example units sent under ``sqlExamplars``.
        example_nums: Value sent under ``exampleNums``.

    The response body is logged; nothing is returned.
    """
    endpoint = f"http://{llm_host}:{llm_port}/text2sql_agent_setting_update/"
    body = json.dumps({"sqlExamplars": sql_examplars, "exampleNums": example_nums})

    response = requests.post(endpoint, data=body, headers={'content-type': 'application/json'})
    logger.info(response.text)
|
||||
|
||||
|
||||
def text2dsl_agent_cs_setting_update(llm_host: str, llm_port: str,
                                     sql_examplars: List[Mapping[str, str]], example_nums: int,
                                     fewshot_nums: int, self_consistency_nums: int):
    """POST exemplars and self-consistency settings to the service.

    Args:
        llm_host: Host of the service to call.
        llm_port: Port of the service to call.
        sql_examplars: Example units sent under ``sqlExamplars``.
        example_nums: Value sent under ``exampleNums``.
        fewshot_nums: Value sent under ``fewshotNums``.
        self_consistency_nums: Value sent under ``selfConsistencyNums``.

    The response body is logged; nothing is returned.
    """
    # NOTE(review): the path segment "texg2sqt_cs_agent_setting_update" looks
    # like a misspelling; confirm against the server's routes before changing it.
    endpoint = f"http://{llm_host}:{llm_port}/texg2sqt_cs_agent_setting_update/"
    body = json.dumps({
        "sqlExamplars": sql_examplars,
        "exampleNums": example_nums,
        "fewshotNums": fewshot_nums,
        "selfConsistencyNums": self_consistency_nums,
    })

    response = requests.post(endpoint, data=body, headers={'content-type': 'application/json'})
    logger.info(response.text)
|
||||
|
||||
|
||||
def text2dsl_agent_wrapper_setting_update(llm_host: str, llm_port: str,
                                          is_shortcut: bool, is_self_consistency: bool,
                                          sql_examplars: List[Mapping[str, str]], example_nums: int,
                                          fewshot_nums: int, self_consistency_nums: int):
    """POST the full agent configuration to the ``query2sql_setting_update`` endpoint.

    Example ids are the positions of the exemplars rendered as strings
    ("0", "1", ...), and are sent under ``sqlIds``.

    Args:
        llm_host: Host of the service to call.
        llm_port: Port of the service to call.
        is_shortcut: Value sent under ``isShortcut``.
        is_self_consistency: Value sent under ``isSelfConsistency``.
        sql_examplars: Example units sent under ``sqlExamplars``.
        example_nums: Value sent under ``exampleNums``.
        fewshot_nums: Value sent under ``fewshotNums``.
        self_consistency_nums: Value sent under ``selfConsistencyNums``.

    The response body is logged; nothing is returned.
    """
    sql_ids = [str(position) for position, _ in enumerate(sql_examplars)]

    endpoint = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
    body = json.dumps({
        "isShortcut": is_shortcut,
        "isSelfConsistency": is_self_consistency,
        "sqlExamplars": sql_examplars,
        "sqlIds": sql_ids,
        "exampleNums": example_nums,
        "fewshotNums": fewshot_nums,
        "selfConsistencyNums": self_consistency_nums,
    })

    response = requests.post(endpoint, data=body, headers={'content-type': 'application/json'})
    logger.info(response.text)
|
||||
|
||||
if __name__ == "__main__":
    # Push the configured exemplars and agent settings to the running service.
    text2dsl_agent_wrapper_setting_update(
        llm_host=LLMPARSER_HOST,
        llm_port=LLMPARSER_PORT,
        is_shortcut=TEXT2DSL_IS_SHORTCUT,
        is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY,
        sql_examplars=sql_examplars,
        example_nums=TEXT2DSL_EXAMPLE_NUM,
        fewshot_nums=TEXT2DSL_FEWSHOTS_NUM,
        self_consistency_nums=TEXT2DSL_SELF_CONSISTENCY_NUM,
    )
|
||||
|
||||
|
||||
57
chat/python/services/sql/output_parser.py
Normal file
57
chat/python/services/sql/output_parser.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import re
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
def schema_link_parse(schema_link_output):
    """Return the text that follows the first ``Schema_links:`` marker.

    The capture runs to the end of the (stripped) input, across newlines,
    and is itself stripped. Any failure (no marker, non-string input) is
    logged with its traceback and ``None`` is returned.
    """
    link_pattern = r"Schema_links:(.*)"
    try:
        captured = re.findall(link_pattern, schema_link_output.strip(), re.DOTALL)
        parsed = captured[0].strip()
    except Exception as e:
        logger.exception(e)
        parsed = None

    return parsed
|
||||
|
||||
|
||||
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
    """Extract the first bracketed ``Schema_links:[...]`` list from a combined output.

    Args:
        schema_linking_sql_combo_output: Raw text expected to contain
            ``Schema_links:[...]``.

    Returns:
        The bracketed list including its brackets (shortest match on a
        single line), or ``None`` when no such list is present or the input
        cannot be processed.
    """
    try:
        schema_links_match = re.search(r"Schema_links:(\[.*?\])",
                                       schema_linking_sql_combo_output.strip())
    except Exception as e:
        # Log with traceback, like the sibling parsers in this module do.
        logger.exception(e)
        return None

    return schema_links_match.group(1) if schema_links_match else None
|
||||
|
||||
|
||||
def combo_sql_parse(schema_linking_sql_combo_output: str):
    """Return the rest of the line that follows the first ``SQL:`` marker.

    The input is stripped first; the captured text is returned as-is (not
    stripped) and stops at the end of the line. ``None`` is returned when
    there is no marker, or when processing fails (the failure is logged with
    its traceback).
    """
    try:
        found = re.search(r"SQL:(.*)", schema_linking_sql_combo_output.strip())
        sql = found.group(1) if found else None
    except Exception as e:
        logger.exception(e)
        sql = None

    return sql
|
||||
54
chat/python/services/sql/run.py
Normal file
54
chat/python/services/sql/run.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from sql.constructor import FewShotPromptTemplate2
|
||||
from sql.sql_agent import Text2DSLAgent, Text2DSLAgentConsistency, Text2DSLAgentWrapper
|
||||
|
||||
from instances.llm_instance import llm
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.chromadb_instance import client
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from few_shot_example.sql_examplar import examplars as sql_examplars
|
||||
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTCS_COLLECTION_NAME,
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
|
||||
|
||||
# Embedding function shared by both example collections.
emb_func = Text2VecEmbeddingFunction()

# One collection per agent flavour; both are created with the
# {"hnsw:space": "cosine"} metadata.
text2dsl_agent_collection = client.get_or_create_collection(
    name=TEXT2DSLAGENT_COLLECTION_NAME,
    embedding_function=emb_func,
    metadata={"hnsw:space": "cosine"},
)
text2dsl_agentcs_collection = client.get_or_create_collection(
    name=TEXT2DSLAGENTCS_COLLECTION_NAME,
    embedding_function=emb_func,
    metadata={"hnsw:space": "cosine"},
)

# Few-shot prompt builders, keyed on each example's "question" field.
text2dsl_agent_example_prompter = FewShotPromptTemplate2(
    collection=text2dsl_agent_collection,
    retrieval_key="question",
    few_shot_seperator='\n\n',
)
text2dsl_agentcs_example_prompter = FewShotPromptTemplate2(
    collection=text2dsl_agentcs_collection,
    retrieval_key="question",
    few_shot_seperator='\n\n',
)

# The plain agent and the self-consistency agent share the same LLM instance.
text2sql_agent = Text2DSLAgent(
    num_fewshots=TEXT2DSL_EXAMPLE_NUM,
    sql_example_prompter=text2dsl_agent_example_prompter,
    llm=llm,
)
text2sql_cs_agent = Text2DSLAgentConsistency(
    num_fewshots=TEXT2DSL_FEWSHOTS_NUM,
    num_examples=TEXT2DSL_EXAMPLE_NUM,
    num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
    sql_example_prompter=text2dsl_agentcs_example_prompter,
    llm=llm,
)

# Load the bundled exemplars into both agents; ids are list positions as strings.
sql_ids = list(map(str, range(len(sql_examplars))))
text2sql_agent.reload_setting(
    sql_example_ids=sql_ids,
    sql_example_units=sql_examplars,
    num_fewshots=TEXT2DSL_EXAMPLE_NUM,
)
text2sql_cs_agent.reload_setting(
    sql_example_ids=sql_ids,
    sql_example_units=sql_examplars,
    num_examples=TEXT2DSL_EXAMPLE_NUM,
    num_fewshots=TEXT2DSL_FEWSHOTS_NUM,
    num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
)

# Router that picks between the agents according to the configured flags.
text2sql_agent_router = Text2DSLAgentWrapper(
    sql_agent=text2sql_agent,
    sql_agent_cs=text2sql_cs_agent,
    is_shortcut=TEXT2DSL_IS_SHORTCUT,
    is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY,
)
|
||||
405
chat/python/services/sql/sql_agent.py
Normal file
405
chat/python/services/sql/sql_agent.py
Normal file
@@ -0,0 +1,405 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping, Any
|
||||
from collections import Counter
|
||||
import random
|
||||
import asyncio
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from sql.constructor import FewShotPromptTemplate2
|
||||
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
|
||||
|
||||
class Text2DSLAgent(object):
    """Text-to-SQL agent driven by retrieved few-shot examples.

    Two flows are offered: ``async_query2sql`` asks the LLM for schema links
    and then for SQL in two separate calls, while
    ``async_query2sql_shortcut`` asks for both in a single call.
    """

    def __init__(self, num_fewshots: int,
                 sql_example_prompter: "FewShotPromptTemplate2",
                 llm: "BaseLLM"):
        """Store the few-shot count, the example prompter and the LLM."""
        self.num_fewshots = num_fewshots
        self.sql_example_prompter = sql_example_prompter
        self.llm = llm

    def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]],
                       num_fewshots: int):
        """Set ``num_fewshots`` and reload the example collection from scratch."""
        self.num_fewshots = num_fewshots

        self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)

    def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        """Add examples through the prompter."""
        self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)

    def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        """Update existing examples through the prompter."""
        self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)

    def delete_examples(self, sql_example_ids: List[str]):
        """Delete examples by id through the prompter."""
        self.sql_example_prompter.delete_few_shot_example(sql_example_ids)

    def count_examples(self):
        """Return the number of stored examples reported by the prompter."""
        return self.sql_example_prompter.count_few_shot_example()

    def get_fewshot_examples(self, query_text: str,
                             filter_condition: Mapping[str, str]) -> List[Mapping[str, str]]:
        """Retrieve up to ``num_fewshots`` examples for ``query_text``."""
        return self.sql_example_prompter.retrieve_few_shot_example(
            query_text, self.num_fewshots, filter_condition)

    @staticmethod
    def _format_prior_schema_links(prior_schema_links: Mapping[str, str]) -> str:
        """Render prior links as ``['key'->value,...]`` for use inside a prompt."""
        return '[' + ','.join("'{}'->{}".format(k, v) for k, v in prior_schema_links.items()) + ']'

    @staticmethod
    def _log_request(query_text, model_name, fields_list, data_date,
                     prior_schema_links, prior_exts) -> None:
        """Log the incoming request parameters, one line each."""
        logger.info("query_text: {}".format(query_text))
        logger.info("model_name: {}".format(model_name))
        logger.info("fields_list: {}".format(fields_list))
        logger.info("data_date: {}".format(data_date))
        logger.info("prior_schema_links: {}".format(prior_schema_links))
        logger.info("prior_exts: {}".format(prior_exts))

    @staticmethod
    def _append_prior_exts(query_text: str, prior_exts: str) -> str:
        """Append ``prior_exts`` to the query as a remark, when there is one.

        An empty string or ``None`` leaves the query untouched (``None``
        used to raise ``TypeError`` on string concatenation).
        """
        if not prior_exts:
            return query_text

        query_text = query_text + ' 备注:' + prior_exts
        logger.info("query_text_prior_exts: {}".format(query_text))
        return query_text

    def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
                                       prior_schema_links: Mapping[str, str],
                                       fewshot_example_list: List[Mapping[str, str]]) -> str:
        """Build the schema-linking prompt: instruction, few-shot examples, new case."""
        prior_schema_links_str = self._format_prior_schema_links(prior_schema_links)

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

        schema_linking_example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"]
        schema_linking_example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"
        schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template=schema_linking_example_template,
            example_keys=schema_linking_example_keys,
            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list,
                                                   priorSchemaLinks=prior_schema_links_str,
                                                   question=user_query)

        return instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt

    def generate_sql_prompt(self, user_query: str, domain_name: str,
                            schema_link_str: str, data_date: str,
                            fewshot_example_list: List[Mapping[str, str]]) -> str:
        """Build the SQL-generation prompt; the new case ends with an open ``SQL:``."""
        instruction = "# 根据schema_links为每个问题生成SQL查询语句"
        sql_example_keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
        sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"

        sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template=sql_example_template,
            example_keys=sql_example_keys,
            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
        new_case_prompt = new_case_template.format(question=user_query, currentDate=data_date,
                                                   tableName=domain_name, schemaLinks=schema_link_str)

        return instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt

    def generate_schema_linking_sql_prompt(self, user_query: str,
                                           domain_name: str,
                                           data_date: str,
                                           fields_list: List[str],
                                           prior_schema_links: Mapping[str, str],
                                           fewshot_example_list: List[Mapping[str, str]]):
        """Build the combined prompt that asks for schema links and SQL in one answer."""
        prior_schema_links_str = self._format_prior_schema_links(prior_schema_links)

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"

        example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "currentDate", "question", "analysis", "schemaLinks", "sql"]
        example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}\nSQL:{sql}"
        fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template=example_template,
            example_keys=example_keys,
            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list,
                                                   priorSchemaLinks=prior_schema_links_str,
                                                   currentDate=data_date, question=user_query)

        return instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt

    async def async_query2sql(self, query_text: str, filter_condition: Mapping[str, str],
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        """Run the two-call flow: schema linking, then SQL generation.

        Returns:
            A dict echoing the request (``query``, ``model``, ``fields``,
            ``priorSchemaLinking``, ``dataDate``) plus the raw schema-linking
            output (``schemaLinkingOutput``), the parsed links
            (``schemaLinkStr``, ``None`` if parsing failed) and the raw SQL
            output (``sqlOutput``).
        """
        self._log_request(query_text, model_name, fields_list, data_date, prior_schema_links, prior_exts)
        query_text = self._append_prior_exts(query_text, prior_exts)

        # The same retrieved examples feed both prompts.
        fewshot_example_meta_list = self.get_fewshot_examples(query_text, filter_condition)

        schema_linking_prompt = self.generate_schema_linking_prompt(
            query_text, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
        logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
        schema_link_output = await self.llm._call_async(schema_linking_prompt)

        schema_link_str = schema_link_parse(schema_link_output)

        sql_prompt = self.generate_sql_prompt(
            query_text, model_name, schema_link_str, data_date, fewshot_example_meta_list)
        logger.debug("sql_prompt->{}".format(sql_prompt))
        sql_output = await self.llm._call_async(sql_prompt)

        resp = {
            'query': query_text,
            'model': model_name,
            'fields': fields_list,
            'priorSchemaLinking': prior_schema_links,
            'dataDate': data_date,
            'schemaLinkingOutput': schema_link_output,
            'schemaLinkStr': schema_link_str,
            'sqlOutput': sql_output,
        }

        logger.info("resp: {}".format(resp))

        return resp

    async def async_query2sql_shortcut(self, query_text: str, filter_condition: Mapping[str, str],
                                       model_name: str, fields_list: List[str],
                                       data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        """Run the single-call flow and parse schema links and SQL from one answer.

        Returns:
            A dict echoing the request (``query``, ``model``, ``fields``,
            ``priorSchemaLinking``, ``dataDate``) plus the raw LLM output
            (``schemaLinkingComboOutput``), the parsed links
            (``schemaLinkStr``) and the parsed SQL (``sqlOutput``); either
            parsed value is ``None`` when its marker is not found.
        """
        self._log_request(query_text, model_name, fields_list, data_date, prior_schema_links, prior_exts)
        query_text = self._append_prior_exts(query_text, prior_exts)

        fewshot_example_meta_list = self.get_fewshot_examples(query_text, filter_condition)

        schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(
            query_text, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
        logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
        schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)

        schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
        sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)

        resp = {
            'query': query_text,
            'model': model_name,
            'fields': fields_list,
            'priorSchemaLinking': prior_schema_links,
            'dataDate': data_date,
            'schemaLinkingComboOutput': schema_linking_sql_shortcut_output,
            'schemaLinkStr': schema_linking_str,
            'sqlOutput': sql_str,
        }

        logger.info("resp: {}".format(resp))

        return resp
|
||||
|
||||
class Text2DSLAgentConsistency(object):
    """Text-to-SQL agent that uses self-consistency voting.

    ``num_examples`` candidates are retrieved, ``num_self_consistency``
    random subsets of ``num_fewshots`` examples are drawn from them, one
    prompt is built per subset, the LLM calls are awaited together, and the
    most frequent answer wins.
    """

    def __init__(self, num_fewshots: int, num_examples: int, num_self_consistency: int,
                 sql_example_prompter: "FewShotPromptTemplate2", llm: "BaseLLM") -> None:
        """Store the settings; ``num_fewshots`` must not exceed ``num_examples``."""
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency

        self.llm = llm
        self.sql_example_prompter = sql_example_prompter

    def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]],
                       num_examples: int, num_fewshots: int, num_self_consistency: int):
        """Replace all settings and reload the example collection from scratch."""
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency
        assert self.num_self_consistency >= 1
        self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)

    def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        """Add examples through the prompter."""
        self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)

    def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        """Update existing examples through the prompter."""
        self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)

    def delete_examples(self, sql_example_ids: List[str]):
        """Delete examples by id through the prompter."""
        self.sql_example_prompter.delete_few_shot_example(sql_example_ids)

    def count_examples(self):
        """Return the number of stored examples reported by the prompter."""
        return self.sql_example_prompter.count_few_shot_example()

    def get_examples_candidates(self, query_text: str,
                                filter_condition: Mapping[str, str]) -> List[Mapping[str, str]]:
        """Retrieve up to ``num_examples`` candidate examples for ``query_text``."""
        return self.sql_example_prompter.retrieve_few_shot_example(
            query_text, self.num_examples, filter_condition)

    def get_fewshot_example_combos(self,
                                   example_meta_list: List[Mapping[str, str]]) -> List[List[Mapping[str, str]]]:
        """Draw ``num_self_consistency`` random subsets of the candidates.

        Each subset holds the first ``num_fewshots`` items of a fresh random
        ordering (fewer if there are not enough candidates). The input list
        is left untouched; it used to be shuffled in place.
        """
        fewshot_example_list = []
        for _ in range(self.num_self_consistency):
            # random.sample over the full length yields a shuffled copy.
            shuffled = random.sample(example_meta_list, len(example_meta_list))
            fewshot_example_list.append(shuffled[:self.num_fewshots])

        return fewshot_example_list

    def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
                                       prior_schema_links: Mapping[str, str],
                                       fewshot_example_list: List[Mapping[str, str]]) -> str:
        """Build the schema-linking prompt: instruction, few-shot examples, new case."""
        prior_schema_links_str = '[' + ','.join(
            "'{}'->{}".format(k, v) for k, v in prior_schema_links.items()) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

        schema_linking_example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"]
        schema_linking_example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"
        schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template=schema_linking_example_template,
            example_keys=schema_linking_example_keys,
            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list,
                                                   priorSchemaLinks=prior_schema_links_str,
                                                   question=user_query)

        return instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt

    def generate_schema_linking_prompt_pool(self, user_query: str, domain_name: str, fields_list: List[str],
                                            prior_schema_links: Mapping[str, str],
                                            fewshot_example_list_pool: List[List[Mapping[str, str]]]) -> List[str]:
        """Build one schema-linking prompt per few-shot subset."""
        return [
            self.generate_schema_linking_prompt(user_query, domain_name, fields_list,
                                                prior_schema_links, fewshot_example_list)
            for fewshot_example_list in fewshot_example_list_pool
        ]

    def generate_sql_prompt(self, user_query: str, domain_name: str,
                            schema_link_str: str, data_date: str,
                            fewshot_example_list: List[Mapping[str, str]]) -> str:
        """Build the SQL-generation prompt; the new case ends with an open ``SQL:``."""
        instruction = "# 根据schema_links为每个问题生成SQL查询语句"
        sql_example_keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
        sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"

        sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template=sql_example_template,
            example_keys=sql_example_keys,
            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
        new_case_prompt = new_case_template.format(question=user_query, currentDate=data_date,
                                                   tableName=domain_name, schemaLinks=schema_link_str)

        return instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt

    def generate_sql_prompt_pool(self, user_query: str, domain_name: str, data_date: str,
                                 schema_link_str_pool: List[str],
                                 fewshot_example_list_pool: List[List[Mapping[str, str]]]) -> List[str]:
        """Build one SQL prompt per (schema links, few-shot subset) pair."""
        return [
            self.generate_sql_prompt(user_query, domain_name, schema_link_str, data_date, fewshot_example_list)
            for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool)
        ]

    def self_consistency_vote(self, output_res_pool: List[str]):
        """Pick the most frequent output and report every output's share.

        Returns:
            ``(winner, {output: fraction_of_votes})``. An empty pool yields
            ``(None, {})`` instead of raising ``IndexError``.
        """
        if not output_res_pool:
            return None, {}

        output_res_counts = Counter(output_res_pool)
        output_res_max = output_res_counts.most_common(1)[0][0]
        total_output_num = len(output_res_pool)

        vote_percentage = {k: (v / total_output_num) for k, v in output_res_counts.items()}

        return output_res_max, vote_percentage

    def schema_linking_list_str_unify(self, schema_linking_list: List[str]) -> List[str]:
        """Normalise schema-link strings so equal sets compare equal.

        Each string has its surrounding brackets removed, is split on
        commas, and its trimmed items are sorted and re-joined as
        ``[a,b,...]``. ``None`` entries (outputs that failed to parse) are
        skipped rather than raising ``AttributeError``; they cast no vote.
        """
        schema_linking_list_unify = []
        for schema_linking_str in schema_linking_list:
            if schema_linking_str is None:
                continue
            items = sorted(item.strip() for item in schema_linking_str.strip('[]').split(','))
            schema_linking_list_unify.append('[' + ','.join(items) + ']')

        return schema_linking_list_unify

    async def generate_schema_linking_tasks(self, user_query: str, domain_name: str,
                                            fields_list: List[str], prior_schema_links: Mapping[str, str],
                                            fewshot_example_list_combo: List[List[Mapping[str, str]]]):
        """Await the LLM's schema-linking answer for every few-shot subset."""
        schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(
            user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
        schema_linking_output_task_pool = [
            self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool
        ]
        schema_linking_output_res_pool = await asyncio.gather(*schema_linking_output_task_pool)
        logger.debug(f'schema_linking_output_res_pool:{schema_linking_output_res_pool}')

        return schema_linking_output_res_pool

    async def generate_sql_tasks(self, user_query: str, domain_name: str, data_date: str,
                                 schema_link_str_pool: List[str],
                                 fewshot_example_list_combo: List[List[Mapping[str, str]]]):
        """Await the LLM's SQL answer for every (schema links, few-shot subset) pair."""
        # Keyword arguments on purpose: the date and the schema-link pool
        # were once passed in swapped positions, which zipped the date
        # string character by character as if it were the schema links.
        sql_prompt_pool = self.generate_sql_prompt_pool(
            user_query=user_query,
            domain_name=domain_name,
            data_date=data_date,
            schema_link_str_pool=schema_link_str_pool,
            fewshot_example_list_pool=fewshot_example_list_combo)
        sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
        sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
        logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')

        return sql_output_res_pool

    async def tasks_run(self, user_query: str, filter_condition: Mapping[str, str], domain_name: str,
                        fields_list: List[str], prior_schema_links: Mapping[str, str],
                        data_date: str, prior_exts: str):
        """Run the full self-consistency flow for one query.

        Returns:
            A dict echoing the request (``query``, ``model``, ``fields``,
            ``priorSchemaLinking``, ``dataDate``) plus the winning schema
            links and SQL (``schemaLinkStr``, ``sqlOutput``) and their vote
            shares (``schemaLinkingWeight``, ``sqlWeight``).
        """
        logger.info("user_query: {}".format(user_query))
        logger.info("domain_name: {}".format(domain_name))
        logger.info("fields_list: {}".format(fields_list))
        logger.info("current_date: {}".format(data_date))
        logger.info("prior_schema_links: {}".format(prior_schema_links))
        logger.info("prior_exts: {}".format(prior_exts))

        # Truthiness check: None is treated like an empty remark instead of
        # failing on string concatenation.
        if prior_exts:
            user_query = user_query + ' 备注:' + prior_exts
            logger.info("user_query_prior_exts: {}".format(user_query))

        fewshot_example_meta_list = self.get_examples_candidates(user_query, filter_condition)
        fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list)

        schema_linking_output_candidates = await self.generate_schema_linking_tasks(
            user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
        schema_linking_candidate_list = [
            schema_link_parse(schema_linking_output_candidate)
            for schema_linking_output_candidate in schema_linking_output_candidates
        ]
        logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
        schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
        logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')

        schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(
            schema_linking_candidate_sorted_list)

        # Each SQL prompt uses the raw links parsed for its own subset, not
        # the voted winner.
        sql_output_candicates = await self.generate_sql_tasks(
            user_query, domain_name, data_date, schema_linking_candidate_list, fewshot_example_list_combo)
        logger.debug(f'sql_output_candicates:{sql_output_candicates}')
        sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)

        resp = {
            'query': user_query,
            'model': domain_name,
            'fields': fields_list,
            'priorSchemaLinking': prior_schema_links,
            'dataDate': data_date,
            'schemaLinkStr': schema_linking_output_max,
            'schemaLinkingWeight': schema_linking_output_vote_percentage,
            'sqlOutput': sql_output_max,
            'sqlWeight': sql_output_vote_percentage,
        }

        logger.info("resp: {}".format(resp))

        return resp
|
||||
|
||||
class Text2DSLAgentWrapper(object):
    """Route a query to the self-consistency, shortcut or normal flow.

    ``is_self_consistency`` takes precedence over ``is_shortcut``; when both
    are false the normal two-call flow is used.
    """

    def __init__(self, sql_agent: "Text2DSLAgent", sql_agent_cs: "Text2DSLAgentConsistency",
                 is_shortcut: bool, is_self_consistency: bool):
        """Store both agents and the two routing flags."""
        self.sql_agent = sql_agent
        self.sql_agent_cs = sql_agent_cs

        self.is_shortcut = is_shortcut
        self.is_self_consistency = is_self_consistency

    async def async_query2sql(self, query_text: str, filter_condition: Mapping[str, str],
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        """Dispatch the query to the configured flow and return that flow's response dict."""
        if self.is_self_consistency:
            logger.info("sql wrapper: self_consistency")
            return await self.sql_agent_cs.tasks_run(
                user_query=query_text, filter_condition=filter_condition, domain_name=model_name,
                fields_list=fields_list, prior_schema_links=prior_schema_links,
                data_date=data_date, prior_exts=prior_exts)

        if self.is_shortcut:
            logger.info("sql wrapper: shortcut")
            return await self.sql_agent.async_query2sql_shortcut(
                query_text=query_text, filter_condition=filter_condition, model_name=model_name,
                fields_list=fields_list, data_date=data_date,
                prior_schema_links=prior_schema_links, prior_exts=prior_exts)

        logger.info("sql wrapper: normal")
        return await self.sql_agent.async_query2sql(
            query_text=query_text, filter_condition=filter_condition, model_name=model_name,
            fields_list=fields_list, data_date=data_date,
            prior_schema_links=prior_schema_links, prior_exts=prior_exts)

    def update_configs(self, is_shortcut, is_self_consistency,
                       sql_examplars, num_examples, num_fewshots, num_self_consistency,
                       sql_example_ids=None):
        """Update the routing flags and reload both agents with new examples.

        Args:
            is_shortcut: New value of the shortcut flag.
            is_self_consistency: New value of the self-consistency flag.
            sql_examplars: Example units loaded into both agents.
            num_examples: Few-shot count of the plain agent and candidate
                count of the self-consistency agent.
            num_fewshots: Few-shot count of the self-consistency agent.
            num_self_consistency: Number of prompts the self-consistency
                agent votes over.
            sql_example_ids: Ids for ``sql_examplars``; defaults to the
                list positions rendered as strings ("0", "1", ...).
        """
        self.is_shortcut = is_shortcut
        self.is_self_consistency = is_self_consistency

        if sql_example_ids is None:
            sql_example_ids = [str(i) for i in range(len(sql_examplars))]

        # reload_setting is the method that accepts these settings; the
        # agents' update_examples takes only ids and units, so the previous
        # keyword-argument calls to it always raised TypeError.
        self.sql_agent.reload_setting(sql_example_ids=sql_example_ids,
                                      sql_example_units=sql_examplars,
                                      num_fewshots=num_examples)
        self.sql_agent_cs.reload_setting(sql_example_ids=sql_example_ids,
                                         sql_example_units=sql_examplars,
                                         num_examples=num_examples,
                                         num_fewshots=num_fewshots,
                                         num_self_consistency=num_self_consistency)
|
||||
|
||||
Reference in New Issue
Block a user