mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
add auto-CoT feature (#483)
* 1.refactor the retrieval module. 2.refactor the http service module. 3.upgrade text2sql output format the parse for absolute time related expression in query. * fix bug. * upgrade the config module, now support config llm supported by langchain. * fix conflicts. * update text2sql config reload to be compatible with new config format. * modify default config. * 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantic retriever modules. * merge with upstream master * add general retrieve service. * add api service for sql_agent for crud operations of few-shots examples. * modify requirements * add auto-cot feature --------- Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
@@ -22,7 +22,7 @@ from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_colle
|
||||
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
|
||||
empty_chroma_collection_2)
|
||||
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
from utils.text2vec import Text2VecEmbeddingFunction
|
||||
|
||||
class ChromaCollectionRetriever(object):
|
||||
def __init__(self, collection:Collection):
|
||||
|
||||
@@ -14,7 +14,7 @@ import chromadb
|
||||
from chromadb.config import Settings
|
||||
from chromadb.api import Collection, Documents, Embeddings
|
||||
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
from utils.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.chromadb_instance import client
|
||||
|
||||
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
|
||||
|
||||
167
chat/python/services/s2ql/auto_cot.py
Normal file
167
chat/python/services/s2ql/auto_cot.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from typing import Any, List, Mapping, Optional, Union, Tuple
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
from instances.text2vec_instance import emb_func
|
||||
|
||||
from sqlglot import parse_one, exp
|
||||
import numpy as np
|
||||
|
||||
def sql2schema_linking(sql: str):
    """Extract the columns and literals referenced by a SQL statement.

    The statement is parsed with sqlglot. Each column qualifier is resolved
    to a table name through the aliases declared in the statement.

    Args:
        sql: SQL text to analyse.

    Returns:
        A dict with two keys:
        ``'fields'`` - de-duplicated ``(table_name, column_name)`` tuples in
        first-seen order; ``table_name`` is ``''`` for unqualified columns.
        ``'literals'`` - ``output_name`` of every literal expression, in the
        order sqlglot yields them (duplicates are kept).

    Raises:
        ValueError: a column qualifier is neither a declared alias nor a
            table name found in the statement.
    """
    sql_ast = parse_one(sql)

    literals = [literal.output_name for literal in sql_ast.find_all(exp.Literal)]

    fields_raw = [
        {
            'column_table_alias': column.table,
            'column_name': column.name,
        }
        for column in sql_ast.find_all(exp.Column)
    ]

    table_alias_map = dict()
    for table in sql_ast.find_all(exp.Table):
        # The first table seen for a given alias wins.
        if table.alias not in table_alias_map:
            table_alias_map[table.alias] = table.name

    logger.debug(f'literals: {literals}')
    logger.debug(f'fields_raw: {fields_raw}')
    logger.debug(f'table_alias_map: {table_alias_map}')

    fields = []
    for field in fields_raw:
        column_table_alias = field['column_table_alias']
        column_name = field['column_name']

        if column_table_alias.strip() == '':
            # Unqualified column: no table information is available.
            column_table = ''
        elif column_table_alias in table_alias_map:
            column_table = table_alias_map[column_table_alias]
        elif column_table_alias in table_alias_map.values():
            # The qualifier is already a real table name.
            column_table = column_table_alias
        else:
            message = f'column_table_alias: {column_table_alias} not in table_alias_map: {table_alias_map}'
            logger.error(message)
            raise ValueError(message)

        fields.append((column_table, column_name))

    return {
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result does not depend on string hash randomisation like set() does.
        'fields': list(dict.fromkeys(fields)),
        'literals': literals
    }
|
||||
|
||||
|
||||
def get_question_slices(question: str, min_window_size: int, max_window_size: int):
    """Return every contiguous substring of ``question`` within a length range.

    Slices are ordered by start index, then by increasing length.

    Args:
        question: text to slice.
        min_window_size: smallest slice length; must be greater than 1.
        max_window_size: largest slice length; must not exceed ``len(question)``.

    Returns:
        List of substrings whose length lies in
        ``[min_window_size, max_window_size]``.

    Raises:
        AssertionError: the window bounds violate the constraints above.
    """
    assert min_window_size <= max_window_size
    assert min_window_size > 1
    assert max_window_size < len(question)+1

    question_length = len(question)
    question_slices = []
    for start in range(question_length):
        # Enumerate only the admissible window sizes instead of testing every
        # (start, end) pair of the string.
        for window_size in range(min_window_size, max_window_size + 1):
            end = start + window_size
            if end > question_length:
                break
            question_slices.append(question[start:end])

    return question_slices
|
||||
|
||||
|
||||
def schema_linking_match(fields: List[Tuple[str,str]], question: str, min_window_size: int, max_window_size: int):
    """Pair every field with the question slice most similar to its column name.

    The question is cut into slices with ``get_question_slices``; slices and
    column names are embedded with ``emb_func`` and compared by cosine
    similarity.

    Args:
        fields: ``(table_name, column_name)`` tuples; ``table_name`` may be ``''``.
        question: question text to slice.
        min_window_size: smallest slice length.
        max_window_size: largest slice length.

    Returns:
        List of ``(field_label, question_slice)`` tuples, one per field, where
        ``field_label`` is ``'table.column'`` or just ``'column'`` when the
        table name is empty. Empty list when ``fields`` is empty.
    """
    question_slices = get_question_slices(question, min_window_size, max_window_size)
    assert len(question_slices) > 0
    logger.debug('question_slices_len:{}'.format(len(question_slices)))
    logger.debug(f'question_slices: {question_slices}')

    if len(fields) == 0:
        # Nothing to link: skip both embedding calls.
        return []

    question_slices_embeddings = np.array(emb_func(question_slices))  # (n_question_slices, dim)
    fields_embeddings = np.array(emb_func([field[1] for field in fields]))  # (n_fields, dim)

    # Row-normalise both matrices so the dot product is the cosine similarity.
    question_slices_embeddings_norm = question_slices_embeddings / np.linalg.norm(question_slices_embeddings, axis=1, keepdims=True)
    fields_embeddings_norm = fields_embeddings / np.linalg.norm(fields_embeddings, axis=1, keepdims=True)

    fields_question_slices_similarity = np.matmul(fields_embeddings_norm, question_slices_embeddings_norm.T)  # (n_fields, n_question_slices)
    logger.debug('fields_question_slices_similarity_max:{}'.format(np.max(fields_question_slices_similarity, axis=1)))
    fields_question_slices_argmax = np.argmax(fields_question_slices_similarity, axis=1)  # (n_fields, )
    logger.debug('fields_question_slices_argmax:{}'.format(fields_question_slices_argmax))

    fields_question_slices_pair = []
    for (table_name, column_name), best_slice_index in zip(fields, fields_question_slices_argmax):
        best_slice = question_slices[best_slice_index]
        if table_name != "":
            fields_question_slices_pair.append((table_name + '.' + column_name, best_slice))
        else:
            fields_question_slices_pair.append((column_name, best_slice))

    logger.debug(f'fields_question_slices_pair: {fields_question_slices_pair}')

    return fields_question_slices_pair
|
||||
|
||||
|
||||
def construct_schema_linking_cot(question:str, fields_question_slices_pair:List[Tuple[str,str]], literals_list:List[str]):
    """Build the schema-linking chain-of-thought text for one question.

    Args:
        question: question text quoted in the introduction line.
        fields_question_slices_pair: ``(field, question_slice)`` tuples; each
            one produces a reasoning line.
        literals_list: literal values listed as possible cell values.

    Returns:
        ``(cot, schema_linking_str)`` where ``schema_linking_str`` is the
        bracketed, comma-joined list of fields followed by literals, and
        ``cot`` is the full text ending with ``Schema_links: <that list>``.
    """
    intro_line = """Let’s think step by step. In the question "{question}", we are asked:""".format(question=question)
    field_line_template = """"{question_slice}" so we need column = [{field}]"""
    literal_line_template = """Based on the tables, columns, and Foreign_keys, The set of possible cell values are = [{literals}]. So the Schema_links are:"""

    field_lines = [
        field_line_template.format(question_slice=question_slice, field=field)
        for field, question_slice in fields_question_slices_pair
    ]
    linked_fields = [field for field, _ in fields_question_slices_pair]

    literal_line = literal_line_template.format(literals=",".join(literals_list))

    # Fields first, then literals, exactly as they appear in the prompt.
    schema_linking_str = '[' + ",".join([*linked_fields, *literals_list]) + ']'

    cot_lines = [intro_line]
    if field_lines:
        cot_lines.append('\n'.join(field_lines))
    cot_lines.append(literal_line)
    cot_lines.append('Schema_links: ' + schema_linking_str)

    return '\n'.join(cot_lines), schema_linking_str
|
||||
|
||||
def auto_cot_run(question, sql, min_window_size, max_window_size):
    """Generate the schema-linking chain-of-thought for a question/SQL pair.

    Parses ``sql`` with ``sql2schema_linking``, matches its fields against
    slices of ``question`` with ``schema_linking_match`` and renders the
    result with ``construct_schema_linking_cot``.

    Returns:
        ``(auto_schema_linking_cot, auto_schema_linkings)`` as returned by
        ``construct_schema_linking_cot``.
    """
    sql_entity = sql2schema_linking(sql)
    logger.debug(f'sql_entity: {sql_entity}')

    field_linked_pairs = schema_linking_match(
        sql_entity['fields'], question, min_window_size, max_window_size
    )
    logger.debug(f'field_linked_pairs: {field_linked_pairs}')

    cot_and_linkings = construct_schema_linking_cot(
        question, field_linked_pairs, sql_entity['literals']
    )
    auto_schema_linking_cot, auto_schema_linkings = cot_and_linkings
    logger.debug(f'auto_schema_linking_cot: {auto_schema_linking_cot}')
    logger.debug(f'auto_schema_linkings: {auto_schema_linkings}')

    return auto_schema_linking_cot, auto_schema_linkings
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke run on one hard-coded question/SQL pair.
    question = "没有获得过奖项的高校有哪几所?"
    sql = "select 名称 from 高校 where 词条id not in ( select 高校id from 奖项 )"
    min_window_size, max_window_size = 6, 10

    generated_schema_linking_cot, generated_schema_linkings = auto_cot_run(
        question, sql, min_window_size, max_window_size
    )
|
||||
82
chat/python/services/s2ql/auto_cot_run.py
Normal file
82
chat/python/services/s2ql/auto_cot_run.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List, Union, Mapping
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from auto_cot import auto_cot_run
|
||||
|
||||
|
||||
|
||||
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None):
    """Render a question and its schema into the text form used in prompts.

    Args:
        question: raw question text.
        current_date: date text appended to the question.
        table_name: table name written into the schema description.
        field_list: columns, interpolated as-is into the schema description.
        prior_linkings: either a string such as ``"[value->type,...]"`` or a
            mapping of entity value to entity type. Any other type yields no
            linking text.
        prior_exts: extra remark appended to the question.
        sql: passed through unchanged.

    Returns:
        ``(question_augmented, db_schema, sql)``.

    Raises:
        ValueError: a string entry does not split into exactly two parts on ``'->'``.
    """
    linking_template = """’{}‘是一个’{}‘"""

    if isinstance(prior_linkings, str):
        unbracketed = prior_linkings.strip('[]')
        raw_linkings = [] if unbracketed.strip() == '' else unbracketed.split(',')
        logger.debug(f'prior_linkings: {raw_linkings}')

        prior_linkings_pairs = []
        for raw_linking in raw_linkings:
            logger.debug(f'prior_linking: {raw_linking}')
            entity_value, entity_type = raw_linking.split('->')
            prior_linkings_pairs.append(linking_template.format(entity_value, entity_type))
    elif isinstance(prior_linkings, Mapping):
        prior_linkings_pairs = [
            linking_template.format(entity_value, entity_type)
            for entity_value, entity_type in prior_linkings.items()
        ]
    else:
        prior_linkings_pairs = []

    db_schema = f"Table: {table_name}, Columns = {field_list}\nForeign_keys: []"
    current_data_str = """当前的日期是{}""".format(current_date)

    question_augmented = """{question} (补充信息:{prior_linking}。{current_date}) (备注: {prior_exts})""".format(
        question=question,
        prior_linking=','.join(prior_linkings_pairs),
        prior_exts=prior_exts,
        current_date=current_data_str,
    )

    return question_augmented, db_schema, sql
|
||||
|
||||
|
||||
def transform_sql_example_autoCoT_run(examplar_list, min_window_size, max_window_size):
    """Convert raw examples into prompt-ready examples with generated CoT.

    Each example must provide the keys ``question``, ``currentDate``,
    ``tableName``, ``fieldsList``, ``priorSchemaLinks`` and ``sql``;
    ``priorExts`` is optional and defaults to ``''``.

    Returns:
        One dict per example with the keys ``question``, ``questionAugmented``,
        ``dbSchema``, ``sql``, ``generatedSchemaLinkingCoT`` and
        ``generatedSchemaLinkings``.
    """
    transformed_sql_examplar_list = []

    for examplar in examplar_list:
        question_augmented, db_schema, sql = transform_sql_example(
            question=examplar['question'],
            current_date=examplar['currentDate'],
            table_name=examplar['tableName'],
            field_list=examplar['fieldsList'],
            prior_linkings=examplar['priorSchemaLinks'],
            prior_exts=examplar['priorExts'] if 'priorExts' in examplar else '',
            sql=examplar['sql'],
        )
        logger.debug(f'question_augmented: {question_augmented}')
        logger.debug(f'db_schema: {db_schema}')
        logger.debug(f'sql: {sql}')

        # The CoT is generated from the augmented question, not the raw one.
        generated_schema_linking_cot, generated_schema_linkings = auto_cot_run(question_augmented, sql, min_window_size, max_window_size)

        transformed_sql_examplar = {
            'question': examplar['question'],
            'questionAugmented': question_augmented,
            'dbSchema': db_schema,
            'sql': sql,
            'generatedSchemaLinkingCoT': generated_schema_linking_cot,
            'generatedSchemaLinkings': generated_schema_linkings,
        }
        logger.debug(f'transformed_sql_examplar: {transformed_sql_examplar}')

        transformed_sql_examplar_list.append(transformed_sql_examplar)

    return transformed_sql_examplar_list
|
||||
@@ -36,7 +36,10 @@ class FewShotPromptTemplate2(object):
|
||||
self.few_shot_retriever.update_queries(query_text_list=query_text_list, query_id_list=example_ids, metadatas=example_units)
|
||||
|
||||
def delete_few_shot_example(self, example_ids: List[str])-> None:
|
||||
self.few_shot_retriever.delete_queries_by_ids(query_ids=example_ids)
|
||||
self.few_shot_retriever.delete_queries_by_ids(query_ids=example_ids)
|
||||
|
||||
def get_few_shot_example(self, example_ids: List[str]):
|
||||
return self.few_shot_retriever.get_query_by_ids(query_ids=example_ids)
|
||||
|
||||
def count_few_shot_example(self)-> int:
|
||||
return self.few_shot_retriever.get_query_size()
|
||||
@@ -73,3 +76,4 @@ class FewShotPromptTemplate2(object):
|
||||
few_shot_example_str = self.few_shot_seperator.join(few_shot_example_str_unit_list)
|
||||
|
||||
return few_shot_example_str
|
||||
|
||||
40
chat/python/services/s2ql/examples_reload_run.py
Normal file
40
chat/python/services/s2ql/examples_reload_run.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
from config.config_parse import (
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT,)
|
||||
from few_shot_example.s2ql_examplar import examplars as sql_examplars
|
||||
|
||||
|
||||
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
                                          sql_examplars:List[Mapping[str, str]],
                                          example_nums:int, fewshot_nums:int, self_consistency_nums:int,
                                          timeout=None):
    """POST the few-shot examples and text2sql settings to the parser service.

    The examples get the ids ``"0"``, ``"1"``, ... by position and are sent
    to ``http://{llm_host}:{llm_port}/query2sql_setting_update`` as JSON.

    Args:
        llm_host: host of the service.
        llm_port: port of the service.
        sql_examplars: examples to upload.
        example_nums: value sent as ``exampleNums``.
        fewshot_nums: value sent as ``fewshotNums``.
        self_consistency_nums: value sent as ``selfConsistencyNums``.
        timeout: optional ``requests`` timeout in seconds; ``None`` (default)
            waits indefinitely, as before.
    """
    sql_ids = [str(i) for i in range(len(sql_examplars))]

    url = f"http://{llm_host}:{llm_port}/query2sql_setting_update"
    payload = {
        "sqlExamplars":sql_examplars, "sqlIds": sql_ids,
        "exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums
    }
    headers = {'content-type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers, timeout=timeout)

    if not response.ok:
        # Make a rejected update visible instead of burying it at info level.
        logger.error(f'setting update failed with HTTP status {response.status_code}')
    logger.info(response.text)
|
||||
|
||||
if __name__ == "__main__":
    # Upload the bundled examples using the values from the project config.
    text2dsl_agent_wrapper_setting_update(
        llm_host=LLMPARSER_HOST,
        llm_port=LLMPARSER_PORT,
        sql_examplars=sql_examplars,
        example_nums=TEXT2DSL_EXAMPLE_NUM,
        fewshot_nums=TEXT2DSL_FEWSHOTS_NUM,
        self_consistency_nums=TEXT2DSL_SELF_CONSISTENCY_NUM,
    )
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from instances.logging_instance import logger
|
||||
|
||||
|
||||
def schema_link_parse(schema_link_output):
|
||||
def schema_link_parse(schema_link_output: str):
|
||||
try:
|
||||
schema_link_output = schema_link_output.strip()
|
||||
pattern = r"Schema_links:(.*)"
|
||||
pattern = r'Schema_links:(.*)'
|
||||
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
@@ -22,28 +22,29 @@ def schema_link_parse(schema_link_output):
|
||||
|
||||
return schema_link_output
|
||||
|
||||
|
||||
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
|
||||
try:
|
||||
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
|
||||
pattern = r"Schema_links:(\[.*?\])"
|
||||
pattern = r'Schema_links:(\[.*?\])|Schema_links: (\[.*?\])'
|
||||
schema_links_match = re.search(pattern, schema_linking_sql_combo_output)
|
||||
|
||||
if schema_links_match:
|
||||
if schema_links_match.group(1):
|
||||
schema_links = schema_links_match.group(1)
|
||||
elif schema_links_match.group(2):
|
||||
schema_links = schema_links_match.group(2)
|
||||
else:
|
||||
schema_links = None
|
||||
|
||||
except Exception as e:
|
||||
logger.info(e)
|
||||
logger.exception(e)
|
||||
schema_links = None
|
||||
|
||||
return schema_links
|
||||
|
||||
|
||||
def combo_sql_parse(schema_linking_sql_combo_output: str):
|
||||
try:
|
||||
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
|
||||
pattern = r"SQL:(.*)"
|
||||
pattern = r'SQL:(.*)'
|
||||
sql_match = re.search(pattern, schema_linking_sql_combo_output)
|
||||
|
||||
if sql_match:
|
||||
@@ -55,3 +56,4 @@ def combo_sql_parse(schema_linking_sql_combo_output: str):
|
||||
sql = None
|
||||
|
||||
return sql
|
||||
|
||||
63
chat/python/services/s2ql/run.py
Normal file
63
chat/python/services/s2ql/run.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import json
|
||||
|
||||
from s2ql.constructor import FewShotPromptTemplate2
|
||||
from s2ql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
|
||||
|
||||
from instances.llm_instance import llm
|
||||
from instances.chromadb_instance import client as chromadb_client
|
||||
from instances.logging_instance import logger
|
||||
from instances.text2vec_instance import emb_func
|
||||
|
||||
from few_shot_example.s2ql_examplar import examplars as sql_examplars
|
||||
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTACT_COLLECTION_NAME,
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
ACT_MIN_WINDOWN_SIZE, ACT_MAX_WINDOWN_SIZE)
|
||||
|
||||
|
||||
# Chroma collections holding the few-shot examples of each agent.
text2dsl_agent_collection = chromadb_client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME,
                                                                      embedding_function=emb_func,
                                                                      metadata={"hnsw:space": "cosine"})
text2dsl_agent_act_collection = chromadb_client.get_or_create_collection(name=TEXT2DSLAGENTACT_COLLECTION_NAME,
                                                                          embedding_function=emb_func,
                                                                          metadata={"hnsw:space": "cosine"})

text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection,
                                                         retrieval_key="question",
                                                         few_shot_seperator='\n\n')
text2dsl_agent_act_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_act_collection,
                                                             retrieval_key="question",
                                                             few_shot_seperator='\n\n')

text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
                               sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)
text2sql_agent_autoCoT = Text2DSLAgentAutoCoT(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
                                              sql_example_prompter=text2dsl_agent_act_example_prompter, llm=llm,
                                              auto_cot_min_window_size=ACT_MIN_WINDOWN_SIZE, auto_cot_max_window_size=ACT_MAX_WINDOWN_SIZE)

# The plain agent is reloaded from the bundled examples on every import.
sql_ids = [str(i) for i in range(0, len(sql_examplars))]
text2sql_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)

# The auto-CoT agent is seeded from the pre-transformed JSON only while its
# collection is still empty.
if text2sql_agent_autoCoT.count_examples()==0:
    source_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    example_dir_path = os.path.join(source_dir_path, 'few_shot_example')
    example_json_file = os.path.join(example_dir_path, 's2ql_examplar3_transformed.json')
    # Explicit UTF-8 so decoding does not depend on the platform default encoding.
    with open(example_json_file, 'r', encoding='utf-8') as f:
        transformed_sql_examplar_list = json.load(f)

    transformed_sql_examplar_ids = [str(i) for i in range(0, len(transformed_sql_examplar_list))]
    text2sql_agent_autoCoT.reload_setting_autoCoT(transformed_sql_examplar_ids, transformed_sql_examplar_list, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)


text2sql_agent_router = Text2DSLAgentWrapper(sql_agent_act=text2sql_agent_autoCoT)
|
||||
|
||||
778
chat/python/services/s2ql/sql_agent.py
Normal file
778
chat/python/services/s2ql/sql_agent.py
Normal file
@@ -0,0 +1,778 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping, Any
|
||||
from collections import Counter
|
||||
import random
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from s2ql.constructor import FewShotPromptTemplate2
|
||||
from s2ql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
from s2ql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
|
||||
|
||||
|
||||
class Text2DSLAgentBase(object):
    """Shared state and helpers for the text-to-SQL agents.

    Holds the few-shot / self-consistency counters, the LLM and the example
    prompter, plus helpers for sampling few-shot combinations and voting
    over several outputs.
    """

    # Annotations are quoted so defining the class does not require the
    # project types to be resolvable at that moment.
    def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
                 sql_example_prompter:"FewShotPromptTemplate2", llm: "BaseLLM") -> None:
        """Store the settings; ``num_fewshots`` must not exceed ``num_examples``."""
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency

        self.llm = llm
        self.sql_example_prompter = sql_example_prompter

    def get_examples_candidates(self, question: str, filter_condition: Mapping[str, str], num_examples: int)->List[Mapping[str, str]]:
        """Return what the prompter retrieves for ``question`` (up to ``num_examples`` requested)."""
        few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(question, num_examples, filter_condition)

        return few_shot_example_meta_list

    def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]], num_fewshots:int)-> List[List[Mapping[str, str]]]:
        """Draw ``self.num_self_consistency`` random few-shot selections.

        Each selection is the first ``num_fewshots`` items after a shuffle.
        The caller's list is left untouched.
        """
        # Shuffle a private copy: the argument belongs to the caller.
        shuffled_examples = list(example_meta_list)

        fewshot_example_list = []
        for _ in range(self.num_self_consistency):
            random.shuffle(shuffled_examples)
            fewshot_example_list.append(shuffled_examples[:num_fewshots])

        return fewshot_example_list

    def self_consistency_vote(self, output_res_pool:List[str]):
        """Return the most frequent output and the share of votes per output.

        Raises:
            IndexError: ``output_res_pool`` is empty.
        """
        output_res_counts = Counter(output_res_pool)
        output_res_max = output_res_counts.most_common(1)[0][0]
        total_output_num = len(output_res_pool)

        vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()}

        return output_res_max, vote_percentage

    def schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]:
        """Normalise ``"[a, b]"``-style strings: strip items, sort them, re-join with ``,``."""
        schema_linking_list_unify = []
        for schema_linking_str in schema_linking_list:
            items = sorted(item.strip() for item in schema_linking_str.strip('[]').split(','))
            schema_linking_list_unify.append('[' + ','.join(items) + ']')

        return schema_linking_list_unify
|
||||
|
||||
class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
|
||||
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
             sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM,
             auto_cot_min_window_size: int, auto_cot_max_window_size: int):
    """Initialise the base agent and store the auto-CoT window bounds.

    ``auto_cot_min_window_size`` must not exceed ``auto_cot_max_window_size``.
    """
    super().__init__(
        num_fewshots=num_fewshots,
        num_examples=num_examples,
        num_self_consistency=num_self_consistency,
        sql_example_prompter=sql_example_prompter,
        llm=llm,
    )

    assert auto_cot_min_window_size <= auto_cot_max_window_size
    self.auto_cot_min_window_size, self.auto_cot_max_window_size = (
        auto_cot_min_window_size,
        auto_cot_max_window_size,
    )
|
||||
|
||||
def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
    """Replace the stored examples and counters from raw examples.

    The raw examples are first converted with
    ``transform_sql_example_autoCoT_run`` using this agent's window sizes.
    Validation and conversion happen before any attribute is changed, so a
    rejected call leaves the counters as they were.

    Raises:
        AssertionError: ``num_fewshots > num_examples`` or
            ``num_self_consistency < 1``.
    """
    assert num_fewshots <= num_examples
    assert num_self_consistency >= 1

    new_sql_example_unit_list = transform_sql_example_autoCoT_run(sql_example_units, self.auto_cot_min_window_size, self.auto_cot_max_window_size)

    self.num_fewshots = num_fewshots
    self.num_examples = num_examples
    self.num_self_consistency = num_self_consistency

    self.sql_example_prompter.reload_few_shot_example(sql_example_ids, new_sql_example_unit_list)
|
||||
|
||||
def reload_setting_autoCoT(self, sql_example_ids: List[str], auto_cot_sql_example_units: List[Mapping[str,str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
    """Replace the stored examples and counters from already-transformed examples.

    The example units are handed to the prompter as given, without running
    the auto-CoT transformation. Validation happens before any attribute is
    changed, so a rejected call leaves the counters as they were.

    Raises:
        AssertionError: ``num_fewshots > num_examples`` or
            ``num_self_consistency < 1``.
    """
    assert num_fewshots <= num_examples
    assert num_self_consistency >= 1

    self.num_fewshots = num_fewshots
    self.num_examples = num_examples
    self.num_self_consistency = num_self_consistency

    self.sql_example_prompter.reload_few_shot_example(sql_example_ids, auto_cot_sql_example_units)
|
||||
|
||||
def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]]):
    """Transform raw examples with auto-CoT and add them to the prompter."""
    transformed_units = transform_sql_example_autoCoT_run(
        sql_example_units,
        self.auto_cot_min_window_size,
        self.auto_cot_max_window_size,
    )
    self.sql_example_prompter.add_few_shot_example(sql_example_ids, transformed_units)
|
||||
|
||||
def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]]):
    """Transform raw examples with auto-CoT and update them in the prompter."""
    transformed_units = transform_sql_example_autoCoT_run(
        sql_example_units,
        self.auto_cot_min_window_size,
        self.auto_cot_max_window_size,
    )
    self.sql_example_prompter.update_few_shot_example(sql_example_ids, transformed_units)
|
||||
|
||||
def delete_examples(self, sql_example_ids: List[str]):
    """Delete the examples with the given ids through the prompter."""
    prompter = self.sql_example_prompter
    prompter.delete_few_shot_example(example_ids=sql_example_ids)
|
||||
|
||||
def count_examples(self):
    """Return the example count reported by the prompter."""
    example_count = self.sql_example_prompter.count_few_shot_example()
    return example_count
|
||||
|
||||
def get_examples(self, sql_example_ids: List[str]):
    """Return what the prompter holds for the given example ids."""
    prompter = self.sql_example_prompter
    return prompter.get_few_shot_example(example_ids=sql_example_ids)
|
||||
|
||||
def generate_schema_linking_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
                                   prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]])-> str:
    """Build the schema-linking prompt for a new question.

    The prompt is the instruction, the rendered few-shot examples and the
    new case (ending with the opening of the chain-of-thought answer),
    separated by blank lines.
    """
    instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys."

    fewshot_section = self.sql_example_prompter.make_few_shot_example_prompt(
        few_shot_template="{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}",
        example_keys=["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT"],
        few_shot_example_meta_list=fewshot_example_list,
    )

    question_augmented, db_schema, _ = transform_sql_example(
        question=question,
        current_date=current_date,
        table_name=domain_name,
        field_list=fields_list,
        prior_linkings=prior_schema_links,
        prior_exts=prior_exts,
    )
    # The augmented question appears twice: as the Q line and quoted in the answer opener.
    new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Let’s think step by step. In the question "{questionAugmented2}", we are asked:"""
    new_case_section = new_case_template.format(
        dbSchema=db_schema,
        questionAugmented1=question_augmented,
        questionAugmented2=question_augmented,
    )

    return '\n\n'.join([instruction, fewshot_section, new_case_section])
|
||||
|
||||
|
||||
def generate_schema_linking_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
                                        prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
    """Build one schema-linking prompt per few-shot selection in the pool."""
    return [
        self.generate_schema_linking_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_examples)
        for fewshot_examples in fewshot_example_list_pool
    ]
|
||||
|
||||
def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str],
                        schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
                        fewshot_example_list:List[Mapping[str, str]])-> str:
    """Build the SQL-generation prompt for a question with known schema links.

    The prompt is the instruction, the rendered few-shot examples and the
    new case (ending with an open ``SQL: `` line), separated by blank lines.
    """
    instruction = "# Use the the schema links to generate the SQL queries for each of the questions."

    fewshot_section = self.sql_example_prompter.make_few_shot_example_prompt(
        few_shot_template="{dbSchema}\nQ: {questionAugmented}\nSchema_links: {generatedSchemaLinkings}\nSQL: {sql}",
        example_keys=["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"],
        few_shot_example_meta_list=fewshot_example_list,
    )

    question_augmented, db_schema, _ = transform_sql_example(
        question=question,
        current_date=current_date,
        table_name=domain_name,
        field_list=fields_list,
        prior_linkings=prior_schema_links,
        prior_exts=prior_exts,
    )
    new_case_section = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: ".format(
        dbSchema=db_schema,
        questionAugmented=question_augmented,
        schemaLinkings=schema_link_str,
    )

    return '\n\n'.join([instruction, fewshot_section, new_case_section])
|
||||
|
||||
def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str],
                             schema_link_str_pool: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
                             fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
    """Build one SQL prompt per (schema links, few-shot selection) pair.

    The two pools are paired positionally with ``zip``, so the result is as
    long as the shorter pool.
    """
    paired_inputs = zip(schema_link_str_pool, fewshot_example_list_pool)
    return [
        self.generate_sql_prompt(question, domain_name, fields_list, schema_links, current_date, prior_schema_links, prior_exts, fewshot_examples)
        for schema_links, fewshot_examples in paired_inputs
    ]
|
||||
|
||||
def generate_schema_linking_sql_prompt(
    self,
    question: str,
    current_date: str,
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
    fewshot_example_list: List[Mapping[str, str]],
):
    """Build the combined schema-linking + SQL prompt for one question.

    The prompt is the instruction line, the rendered few-shot examples and
    the new case, joined by blank lines. The new case ends with an open
    step-by-step lead-in that repeats the augmented question.

    Returns:
        The assembled prompt string.
    """
    header = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions."

    shots = self.sql_example_prompter.make_few_shot_example_prompt(
        few_shot_template="{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}",
        example_keys=["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql"],
        few_shot_example_meta_list=fewshot_example_list,
    )

    # transform_sql_example yields three values; the third is unused here.
    question_augmented, db_schema, _ = transform_sql_example(
        question, current_date, domain_name, fields_list, prior_schema_links, prior_exts
    )

    # The augmented question is inserted twice: once as "Q:" and once quoted
    # inside the reasoning lead-in.
    new_case = (
        """{dbSchema}\nQ: {questionAugmented1}\nA: Let’s think step by step. In the question "{questionAugmented2}", we are asked:"""
    ).format(
        dbSchema=db_schema,
        questionAugmented1=question_augmented,
        questionAugmented2=question_augmented,
    )

    return "\n\n".join((header, shots, new_case))
|
||||
|
||||
def generate_schema_linking_sql_prompt_pool(
    self,
    question: str,
    current_date: str,
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
    fewshot_example_list_pool: List[List[Mapping[str, str]]],
) -> List[str]:
    """Build one combined schema-linking + SQL prompt per few-shot list.

    Returns:
        The prompts produced by ``generate_schema_linking_sql_prompt``, one
        per entry of ``fewshot_example_list_pool`` and in the same order.
    """
    return [
        self.generate_schema_linking_sql_prompt(
            question, current_date, domain_name, fields_list,
            prior_schema_links, prior_exts, shots,
        )
        for shots in fewshot_example_list_pool
    ]
|
||||
|
||||
async def async_query2sql(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    model_name: str,
    fields_list: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
):
    """Answer a question in two LLM calls: schema linking, then SQL.

    The first call produces the schema-linking output, which is parsed with
    ``schema_link_parse``; the parsed links are then fed into the SQL prompt
    for the second call.

    Returns:
        A dict echoing the inputs (``question``, ``model``, ``fields``,
        ``priorSchemaLinking``, ``currentDate``) plus
        ``schemaLinkingOutput``, ``schemaLinkStr`` and ``sqlOutput``.
    """
    for label, value in (
        ("question", question),
        ("filter_condition", filter_condition),
        ("model_name", model_name),
        ("fields_list", fields_list),
        ("current_date", current_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)

    # Pass 1: ask the LLM for schema links and parse them out of the reply.
    linking_prompt = self.generate_schema_linking_prompt(
        question, current_date, model_name, fields_list,
        prior_schema_links, prior_exts, examples,
    )
    logger.debug(f"schema_linking_prompt->{linking_prompt}")
    schema_link_output = await self.llm._call_async(linking_prompt)
    logger.debug(f"schema_link_output->{schema_link_output}")

    schema_link_str = schema_link_parse(schema_link_output)
    logger.debug(f"schema_link_str->{schema_link_str}")

    # Pass 2: generate SQL from the parsed schema links.
    sql_prompt = self.generate_sql_prompt(
        question, model_name, fields_list, schema_link_str, current_date,
        prior_schema_links, prior_exts, examples,
    )
    logger.debug(f"sql_prompt->{sql_prompt}")
    sql_output = await self.llm._call_async(sql_prompt)

    resp = {
        'question': question,
        'model': model_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'currentDate': current_date,
        'schemaLinkingOutput': schema_link_output,
        'schemaLinkStr': schema_link_str,
        'sqlOutput': sql_output,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
async def async_query2sql_shortcut(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    model_name: str,
    fields_list: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
):
    """Answer a question with a single combined LLM call.

    One prompt asks for schema links and SQL together; both are then
    extracted from the same reply with ``combo_schema_link_parse`` and
    ``combo_sql_parse``.

    Returns:
        A dict echoing the inputs (``question``, ``model``, ``fields``,
        ``priorSchemaLinking``, ``currentDate``) plus
        ``schemaLinkingComboOutput``, ``schemaLinkStr`` and ``sqlOutput``.
    """
    for label, value in (
        ("question", question),
        ("filter_condition", filter_condition),
        ("model_name", model_name),
        ("fields_list", fields_list),
        ("current_date", current_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)

    combo_prompt = self.generate_schema_linking_sql_prompt(
        question, current_date, model_name, fields_list,
        prior_schema_links, prior_exts, examples,
    )
    logger.debug(f"schema_linking_sql_shortcut_prompt->{combo_prompt}")
    combo_output = await self.llm._call_async(combo_prompt)
    logger.debug(f"schema_linking_sql_shortcut_output->{combo_output}")

    # Both results come out of the same raw reply.
    schema_linking_str = combo_schema_link_parse(combo_output)
    sql_str = combo_sql_parse(combo_output)

    resp = {
        'question': question,
        'model': model_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'currentDate': current_date,
        'schemaLinkingComboOutput': combo_output,
        'schemaLinkStr': schema_linking_str,
        'sqlOutput': sql_str,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
async def generate_schema_linking_tasks(
    self,
    question: str,
    model_name: str,
    fields_list: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
    fewshot_example_list_combo: List[List[Mapping[str, str]]],
):
    """Run one schema-linking LLM call per few-shot combination.

    All calls are awaited together with ``asyncio.gather`` and each reply
    is passed through ``schema_link_parse``.

    Returns:
        The parsed schema-link strings, in the order of the prompt pool.
    """
    prompts = self.generate_schema_linking_prompt_pool(
        question, current_date, model_name, fields_list,
        prior_schema_links, prior_exts, fewshot_example_list_combo,
    )
    logger.debug(f"schema_linking_prompt_pool->{prompts}")

    pending = [self.llm._call_async(prompt) for prompt in prompts]
    outputs = await asyncio.gather(*pending)
    logger.debug(f"schema_linking_output_pool->{outputs}")

    return [schema_link_parse(output) for output in outputs]
|
||||
|
||||
async def generate_sql_tasks(
    self,
    question: str,
    model_name: str,
    fields_list: List[str],
    schema_link_str_pool: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
    fewshot_example_list_combo: List[List[Mapping[str, str]]],
):
    """Run one SQL-generation LLM call per prompt in the SQL prompt pool.

    Returns:
        The raw LLM replies as gathered by ``asyncio.gather``, in prompt
        order.
    """
    prompts = self.generate_sql_prompt_pool(
        question, model_name, fields_list, schema_link_str_pool, current_date,
        prior_schema_links, prior_exts, fewshot_example_list_combo,
    )
    logger.debug(f"sql_prompt_pool->{prompts}")

    pending = [self.llm._call_async(prompt) for prompt in prompts]
    outputs = await asyncio.gather(*pending)
    logger.debug(f"sql_output_pool->{outputs}")

    return outputs
|
||||
|
||||
async def generate_schema_linking_sql_tasks(
    self,
    question: str,
    model_name: str,
    fields_list: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
    fewshot_example_list_combo: List[List[Mapping[str, str]]],
):
    """Run one combined schema-linking + SQL LLM call per few-shot combination.

    Returns:
        The raw LLM replies as gathered by ``asyncio.gather``, in prompt
        order.
    """
    prompts = self.generate_schema_linking_sql_prompt_pool(
        question, current_date, model_name, fields_list,
        prior_schema_links, prior_exts, fewshot_example_list_combo,
    )
    outputs = await asyncio.gather(
        *[self.llm._call_async(prompt) for prompt in prompts]
    )
    logger.debug("schema_linking_sql_output_res_pool->{}".format(outputs))

    return outputs
|
||||
|
||||
async def tasks_run(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    model_name: str,
    fields_list: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
):
    """Two-pass text-to-SQL with voting over several few-shot combinations.

    Schema linking is run once per few-shot combination; the normalized
    results are voted on with ``self_consistency_vote``. SQL generation is
    then run once per (unnormalized) schema-linking candidate and voted on
    the same way.

    Returns:
        A dict echoing the inputs (``question``, ``model``, ``fields``,
        ``priorSchemaLinking``, ``currentDate``) plus the winning
        ``schemaLinkStr`` / ``sqlOutput`` and their vote shares
        ``schemaLinkingWeight`` / ``sqlWeight``.
    """
    for label, value in (
        ("question", question),
        ("filter_condition", filter_condition),
        ("model_name", model_name),
        ("fields_list", fields_list),
        ("current_date", current_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)
    combos = self.get_fewshot_example_combos(examples, self.num_fewshots)

    linking_candidates = await self.generate_schema_linking_tasks(
        question, model_name, fields_list, current_date,
        prior_schema_links, prior_exts, combos,
    )
    logger.debug('schema_linking_candidate_list:{}'.format(linking_candidates))

    # Normalize before voting so equivalent link lists count as one answer.
    linking_unified = self.schema_linking_list_str_unify(linking_candidates)
    logger.debug('schema_linking_candidate_sorted_list:{}'.format(linking_unified))
    linking_winner, linking_weights = self.self_consistency_vote(linking_unified)

    # SQL generation uses the raw candidates, one per few-shot combination.
    sql_candidates = await self.generate_sql_tasks(
        question, model_name, fields_list, linking_candidates, current_date,
        prior_schema_links, prior_exts, combos,
    )
    logger.debug('sql_output_candicates:{}'.format(sql_candidates))
    sql_winner, sql_weights = self.self_consistency_vote(sql_candidates)

    resp = {
        'question': question,
        'model': model_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'currentDate': current_date,
        'schemaLinkStr': linking_winner,
        'schemaLinkingWeight': linking_weights,
        'sqlOutput': sql_winner,
        'sqlWeight': sql_weights,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
async def tasks_run_shortcut(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    model_name: str,
    fields_list: List[str],
    current_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
):
    """Single-pass text-to-SQL with voting over several few-shot combinations.

    One combined LLM call is made per few-shot combination. Schema links
    and SQL are parsed out of every reply and each set is voted on
    separately with ``self_consistency_vote``.

    Returns:
        A dict echoing the inputs (``question``, ``model``, ``fields``,
        ``priorSchemaLinking``, ``currentDate``) plus the winning
        ``schemaLinkStr`` / ``sqlOutput`` and their vote shares
        ``schemaLinkingWeight`` / ``sqlWeight``.
    """
    for label, value in (
        ("question", question),
        ("filter_condition", filter_condition),
        ("model_name", model_name),
        ("fields_list", fields_list),
        ("current_date", current_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)
    combos = self.get_fewshot_example_combos(examples, self.num_fewshots)

    combo_outputs = await self.generate_schema_linking_sql_tasks(
        question, model_name, fields_list, current_date,
        prior_schema_links, prior_exts, combos,
    )
    logger.debug('schema_linking_sql_output_candidates:{}'.format(combo_outputs))

    linking_candidates = [combo_schema_link_parse(reply) for reply in combo_outputs]
    logger.debug('schema_linking_sql_output_candidate_list:{}'.format(linking_candidates))

    # Normalize before voting so equivalent link lists count as one answer.
    linking_unified = self.schema_linking_list_str_unify(linking_candidates)
    linking_winner, linking_weights = self.self_consistency_vote(linking_unified)

    sql_candidates = [combo_sql_parse(reply) for reply in combo_outputs]
    logger.debug('sql_output_candidate_list:{}'.format(sql_candidates))
    sql_winner, sql_weights = self.self_consistency_vote(sql_candidates)

    resp = {
        'question': question,
        'model': model_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'currentDate': current_date,
        'schemaLinkStr': linking_winner,
        'schemaLinkingWeight': linking_weights,
        'sqlOutput': sql_winner,
        'sqlWeight': sql_weights,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
class Text2DSLAgent(Text2DSLAgentBase):
|
||||
def __init__(
    self,
    num_fewshots: int,
    num_examples: int,
    num_self_consistency: int,
    sql_example_prompter: FewShotPromptTemplate2,
    llm: BaseLLM,
) -> None:
    """Initialize the agent by forwarding every argument to the base class."""
    super().__init__(
        num_fewshots,
        num_examples,
        num_self_consistency,
        sql_example_prompter,
        llm,
    )
|
||||
|
||||
def reload_setting(
    self,
    sql_example_ids: List[str],
    sql_example_units: List[Mapping[str, str]],
    num_examples: int,
    num_fewshots: int,
    num_self_consistency: int,
):
    """Replace the agent's sampling settings and its few-shot example store.

    The new values are validated before anything is assigned, so a rejected
    call leaves the current settings and examples untouched.

    Args:
        sql_example_ids: Ids handed to the prompter's
            ``reload_few_shot_example``.
        sql_example_units: Example records handed to the prompter's
            ``reload_few_shot_example``.
        num_examples: New value for ``self.num_examples``.
        num_fewshots: New value for ``self.num_fewshots``; must not exceed
            ``num_examples``.
        num_self_consistency: New value for ``self.num_self_consistency``;
            must be at least 1.

    Raises:
        AssertionError: If either constraint is violated. The exception is
            raised explicitly (not via ``assert``) so the check still runs
            under ``python -O``; the type matches the previous behavior.
    """
    if not (num_fewshots <= num_examples):
        raise AssertionError(
            "num_fewshots ({}) must be <= num_examples ({})".format(num_fewshots, num_examples)
        )
    if not (num_self_consistency >= 1):
        raise AssertionError(
            "num_self_consistency ({}) must be >= 1".format(num_self_consistency)
        )

    self.num_fewshots = num_fewshots
    self.num_examples = num_examples
    self.num_self_consistency = num_self_consistency
    self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)
|
||||
|
||||
def add_examples(
    self,
    sql_example_ids: List[str],
    sql_example_units: List[Mapping[str, str]],
):
    """Add few-shot examples by delegating to the example prompter."""
    prompter = self.sql_example_prompter
    prompter.add_few_shot_example(sql_example_ids, sql_example_units)
|
||||
|
||||
def update_examples(
    self,
    sql_example_ids: List[str],
    sql_example_units: List[Mapping[str, str]],
):
    """Update few-shot examples by delegating to the example prompter."""
    prompter = self.sql_example_prompter
    prompter.update_few_shot_example(sql_example_ids, sql_example_units)
|
||||
|
||||
def delete_examples(self, sql_example_ids: List[str]):
    """Delete few-shot examples by id by delegating to the example prompter."""
    prompter = self.sql_example_prompter
    prompter.delete_few_shot_example(sql_example_ids)
|
||||
|
||||
def get_examples(self, sql_example_ids: List[str]):
    """Return whatever the example prompter yields for the given ids."""
    prompter = self.sql_example_prompter
    found = prompter.get_few_shot_example(sql_example_ids)
    return found
|
||||
|
||||
def count_examples(self):
    """Return the example count reported by the example prompter."""
    prompter = self.sql_example_prompter
    total = prompter.count_few_shot_example()
    return total
|
||||
|
||||
def generate_schema_linking_prompt(
    self,
    question: str,
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    fewshot_example_list: List[Mapping[str, str]],
) -> str:
    """Build the schema-linking prompt for one question.

    The prompt is the instruction line, the rendered few-shot examples and
    the new case, joined by blank lines. The new case ends with an open
    step-by-step analysis lead-in.

    Args:
        question: Question text inserted into the new case.
        domain_name: Used as the table name of the new case.
        fields_list: Rendered with its default list formatting as the
            column list.
        prior_schema_links: Rendered as ``['key'->value,...]``.
        fewshot_example_list: Example records handed to the prompter.

    Returns:
        The assembled prompt string.
    """
    # Render the prior links as "['k1'->v1,'k2'->v2]".
    rendered_links = ",".join(
        f"'{key}'->{value}" for key, value in prior_schema_links.items()
    )
    prior_schema_links_str = "[" + rendered_links + "]"

    header = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

    shots = self.sql_example_prompter.make_few_shot_example_prompt(
        few_shot_template="Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}",
        example_keys=["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"],
        few_shot_example_meta_list=fewshot_example_list,
    )

    new_case = (
        "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
    ).format(
        tableName=domain_name,
        fieldsList=fields_list,
        priorSchemaLinks=prior_schema_links_str,
        question=question,
    )

    return "\n\n".join((header, shots, new_case))
|
||||
|
||||
def generate_schema_linking_prompt_pool(
    self,
    question: str,
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    fewshot_example_list_pool: List[List[Mapping[str, str]]],
) -> List[str]:
    """Build one schema-linking prompt per few-shot list.

    Returns:
        The prompts produced by ``generate_schema_linking_prompt``, one per
        entry of ``fewshot_example_list_pool`` and in the same order.
    """
    return [
        self.generate_schema_linking_prompt(
            question, domain_name, fields_list, prior_schema_links, shots
        )
        for shots in fewshot_example_list_pool
    ]
|
||||
|
||||
def generate_sql_prompt(
    self,
    question: str,
    domain_name: str,
    schema_link_str: str,
    data_date: str,
    fewshot_example_list: List[Mapping[str, str]],
) -> str:
    """Build the SQL-generation prompt for one question.

    The prompt is the instruction line, the rendered few-shot examples and
    the new case (ending in an open ``SQL:`` slot), joined by blank lines.

    Args:
        question: Question text inserted into the new case.
        domain_name: Used as the table name of the new case.
        schema_link_str: Schema links inserted into the new case.
        data_date: Inserted as ``Current_date`` of the new case.
        fewshot_example_list: Example records handed to the prompter.

    Returns:
        The assembled prompt string.
    """
    header = "# 根据schema_links为每个问题生成SQL查询语句"

    shots = self.sql_example_prompter.make_few_shot_example_prompt(
        few_shot_template="问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}",
        example_keys=["question", "currentDate", "tableName", "schemaLinks", "sql"],
        few_shot_example_meta_list=fewshot_example_list,
    )

    new_case = (
        "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
    ).format(
        question=question,
        currentDate=data_date,
        tableName=domain_name,
        schemaLinks=schema_link_str,
    )

    return "\n\n".join((header, shots, new_case))
|
||||
|
||||
def generate_sql_prompt_pool(
    self,
    question: str,
    domain_name: str,
    data_date: str,
    schema_link_str_pool: List[str],
    fewshot_example_list_pool: List[List[Mapping[str, str]]],
) -> List[str]:
    """Build one SQL prompt per (schema-link string, few-shot list) pair.

    The two pools are walked in lockstep with ``zip``, so the result has as
    many prompts as the shorter pool.

    Returns:
        The prompts produced by ``generate_sql_prompt``, in pool order.
    """
    return [
        self.generate_sql_prompt(question, domain_name, link_str, data_date, shots)
        for link_str, shots in zip(schema_link_str_pool, fewshot_example_list_pool)
    ]
|
||||
|
||||
def generate_schema_linking_sql_prompt(
    self,
    question: str,
    domain_name: str,
    data_date: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    fewshot_example_list: List[Mapping[str, str]],
):
    """Build the combined schema-linking + SQL prompt for one question.

    The prompt is the instruction line, the rendered few-shot examples and
    the new case, joined by blank lines. The new case ends with an open
    step-by-step analysis lead-in.

    Args:
        question: Question text inserted into the new case.
        domain_name: Used as the table name of the new case.
        data_date: Inserted as ``Current_date`` of the new case.
        fields_list: Rendered with its default list formatting as the
            column list.
        prior_schema_links: Rendered as ``['key'->value,...]``.
        fewshot_example_list: Example records handed to the prompter.

    Returns:
        The assembled prompt string.
    """
    # Render the prior links as "['k1'->v1,'k2'->v2]".
    rendered_links = ",".join(
        f"'{key}'->{value}" for key, value in prior_schema_links.items()
    )
    prior_schema_links_str = "[" + rendered_links + "]"

    header = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"

    shots = self.sql_example_prompter.make_few_shot_example_prompt(
        few_shot_template="Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}\nSQL:{sql}",
        example_keys=["tableName", "fieldsList", "priorSchemaLinks", "currentDate", "question", "analysis", "schemaLinks", "sql"],
        few_shot_example_meta_list=fewshot_example_list,
    )

    new_case = (
        "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析: 让我们一步一步地思考。"
    ).format(
        tableName=domain_name,
        fieldsList=fields_list,
        priorSchemaLinks=prior_schema_links_str,
        currentDate=data_date,
        question=question,
    )

    return "\n\n".join((header, shots, new_case))
|
||||
|
||||
def generate_schema_linking_sql_prompt_pool(
    self,
    question: str,
    domain_name: str,
    fields_list: List[str],
    data_date: str,
    prior_schema_links: Mapping[str, str],
    fewshot_example_list_pool: List[List[Mapping[str, str]]],
) -> List[str]:
    """Build one combined schema-linking + SQL prompt per few-shot list.

    Note that ``generate_schema_linking_sql_prompt`` takes ``data_date``
    before ``fields_list``, the reverse of this method's own parameter
    order.

    Returns:
        One prompt per entry of ``fewshot_example_list_pool``, in order.
    """
    return [
        self.generate_schema_linking_sql_prompt(
            question, domain_name, data_date, fields_list, prior_schema_links, shots
        )
        for shots in fewshot_example_list_pool
    ]
|
||||
|
||||
def self_consistency_vote(self, output_res_pool: List[str]):
    """Pick the most frequent answer and report every answer's vote share.

    Args:
        output_res_pool: Candidate answers; must be non-empty, otherwise
            looking up the top answer raises ``IndexError``.

    Returns:
        A ``(winner, shares)`` tuple, where ``winner`` is the most common
        candidate and ``shares`` maps each distinct candidate to its count
        divided by the pool size.
    """
    tally = Counter(output_res_pool)
    winner = tally.most_common(1)[0][0]
    pool_size = len(output_res_pool)

    shares = {}
    for candidate, votes in tally.items():
        shares[candidate] = votes / pool_size

    return winner, shares
|
||||
|
||||
def schema_linking_list_str_unify(self, schema_linking_list: List[str]) -> List[str]:
    """Normalize schema-link strings so equal link sets compare equal.

    For each string, leading and trailing ``[`` / ``]`` characters are
    stripped, the remainder is split on commas, each item is
    whitespace-trimmed, and the items are sorted and re-joined as
    ``[a,b,c]``.

    Returns:
        The normalized strings, in input order.
    """
    def _normalize(raw: str) -> str:
        items = [piece.strip() for piece in raw.strip('[]').split(',')]
        items.sort()
        return '[' + ','.join(items) + ']'

    return [_normalize(raw) for raw in schema_linking_list]
|
||||
|
||||
async def generate_schema_linking_tasks(
    self,
    question: str,
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    fewshot_example_list_combo: List[List[Mapping[str, str]]],
):
    """Run one schema-linking LLM call per few-shot combination.

    All calls are awaited together with ``asyncio.gather`` and each reply
    is passed through ``schema_link_parse``.

    Returns:
        The parsed schema-link strings, in the order of the prompt pool.
    """
    prompts = self.generate_schema_linking_prompt_pool(
        question, domain_name, fields_list, prior_schema_links,
        fewshot_example_list_combo,
    )
    outputs = await asyncio.gather(
        *[self.llm._call_async(prompt) for prompt in prompts]
    )
    logger.debug('schema_linking_output_pool:{}'.format(outputs))

    return [schema_link_parse(output) for output in outputs]
|
||||
|
||||
async def generate_sql_tasks(
    self,
    question: str,
    domain_name: str,
    data_date: str,
    schema_link_str_pool: List[str],
    fewshot_example_list_combo: List[List[Mapping[str, str]]],
):
    """Run one SQL-generation LLM call per schema-link candidate.

    Args:
        question: Question text forwarded to the prompt builder.
        domain_name: Forwarded to the prompt builder.
        data_date: Forwarded to the prompt builder.
        schema_link_str_pool: Schema-link strings, paired with the few-shot
            combinations by the prompt builder.
        fewshot_example_list_combo: Few-shot example lists, one per prompt.

    Returns:
        The raw LLM replies as gathered by ``asyncio.gather``, in prompt
        order.
    """
    # generate_sql_prompt_pool expects data_date BEFORE schema_link_str_pool.
    # The two used to be passed swapped, which made the pool builder zip over
    # the characters of the date string instead of the schema-link candidates.
    sql_prompt_pool = self.generate_sql_prompt_pool(
        question, domain_name, data_date, schema_link_str_pool,
        fewshot_example_list_combo,
    )
    sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
    sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
    logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')

    return sql_output_res_pool
|
||||
|
||||
async def generate_schema_linking_sql_tasks(
    self,
    question: str,
    domain_name: str,
    fields_list: List[str],
    data_date: str,
    prior_schema_links: Mapping[str, str],
    fewshot_example_list_combo: List[List[Mapping[str, str]]],
):
    """Run one combined schema-linking + SQL LLM call per few-shot combination.

    Returns:
        The raw LLM replies as gathered by ``asyncio.gather``, in prompt
        order.
    """
    prompts = self.generate_schema_linking_sql_prompt_pool(
        question, domain_name, fields_list, data_date, prior_schema_links,
        fewshot_example_list_combo,
    )
    outputs = await asyncio.gather(
        *[self.llm._call_async(prompt) for prompt in prompts]
    )
    logger.debug('schema_linking_sql_output_res_pool:{}'.format(outputs))

    return outputs
|
||||
|
||||
async def tasks_run(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    data_date: str,
    prior_exts: str,
):
    """Two-pass text-to-SQL with voting over several few-shot combinations.

    A non-empty ``prior_exts`` is appended to the question as a remark.
    Schema linking is run once per few-shot combination and the normalized
    results are voted on; SQL generation is then run once per
    (unnormalized) schema-linking candidate and voted on the same way.

    Returns:
        A dict with the (possibly extended) ``question``, ``model``,
        ``fields``, ``priorSchemaLinking``, ``dataDate``, the winning
        ``schemaLinkStr`` / ``sqlOutput`` and their vote shares
        ``schemaLinkingWeight`` / ``sqlWeight``.
    """
    for label, value in (
        ("question", question),
        ("domain_name", domain_name),
        ("fields_list", fields_list),
        ("current_date", data_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    if prior_exts != '':
        question = question + ' 备注:' + prior_exts
    # NOTE(review): indentation was lost in the reviewed copy; this log is
    # assumed to sit outside the `if` above - confirm.
    logger.info(f"question_prior_exts: {question}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)
    combos = self.get_fewshot_example_combos(examples, self.num_fewshots)

    linking_candidates = await self.generate_schema_linking_tasks(
        question, domain_name, fields_list, prior_schema_links, combos
    )
    logger.debug('schema_linking_candidate_list:{}'.format(linking_candidates))

    # Normalize before voting so equivalent link lists count as one answer.
    linking_unified = self.schema_linking_list_str_unify(linking_candidates)
    logger.debug('schema_linking_candidate_sorted_list:{}'.format(linking_unified))
    linking_winner, linking_weights = self.self_consistency_vote(linking_unified)

    # SQL generation uses the raw candidates, one per few-shot combination.
    sql_candidates = await self.generate_sql_tasks(
        question, domain_name, data_date, linking_candidates, combos
    )
    logger.debug('sql_output_candicates:{}'.format(sql_candidates))
    sql_winner, sql_weights = self.self_consistency_vote(sql_candidates)

    resp = {
        'question': question,
        'model': domain_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'dataDate': data_date,
        'schemaLinkStr': linking_winner,
        'schemaLinkingWeight': linking_weights,
        'sqlOutput': sql_winner,
        'sqlWeight': sql_weights,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
async def tasks_run_shortcut(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    domain_name: str,
    fields_list: List[str],
    prior_schema_links: Mapping[str, str],
    data_date: str,
    prior_exts: str,
):
    """Single-pass text-to-SQL with voting over several few-shot combinations.

    A non-empty ``prior_exts`` is appended to the question as a remark. One
    combined LLM call is made per few-shot combination; schema links and
    SQL are parsed out of every reply and each set is voted on separately.

    Returns:
        A dict with the (possibly extended) ``question``, ``model``,
        ``fields``, ``priorSchemaLinking``, ``dataDate``, the winning
        ``schemaLinkStr`` / ``sqlOutput`` and their vote shares
        ``schemaLinkingWeight`` / ``sqlWeight``.
    """
    for label, value in (
        ("question", question),
        ("domain_name", domain_name),
        ("fields_list", fields_list),
        ("current_date", data_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    if prior_exts != '':
        question = question + ' 备注:' + prior_exts
    # NOTE(review): indentation was lost in the reviewed copy; this log is
    # assumed to sit outside the `if` above - confirm.
    logger.info(f"question_prior_exts: {question}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)
    combos = self.get_fewshot_example_combos(examples, self.num_fewshots)

    combo_outputs = await self.generate_schema_linking_sql_tasks(
        question, domain_name, fields_list, data_date, prior_schema_links, combos
    )
    logger.debug('schema_linking_sql_output_candidates:{}'.format(combo_outputs))

    linking_candidates = [combo_schema_link_parse(reply) for reply in combo_outputs]
    logger.debug('schema_linking_sql_output_candidate_list:{}'.format(linking_candidates))

    # Normalize before voting so equivalent link lists count as one answer.
    linking_unified = self.schema_linking_list_str_unify(linking_candidates)
    linking_winner, linking_weights = self.self_consistency_vote(linking_unified)

    sql_candidates = [combo_sql_parse(reply) for reply in combo_outputs]
    logger.debug('sql_output_candidate_list:{}'.format(sql_candidates))
    sql_winner, sql_weights = self.self_consistency_vote(sql_candidates)

    resp = {
        'question': question,
        'model': domain_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'dataDate': data_date,
        'schemaLinkStr': linking_winner,
        'schemaLinkingWeight': linking_weights,
        'sqlOutput': sql_winner,
        'sqlWeight': sql_weights,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
async def async_query2sql(
    self,
    question: str,
    filter_condition: Mapping[str, str],
    model_name: str,
    fields_list: List[str],
    data_date: str,
    prior_schema_links: Mapping[str, str],
    prior_exts: str,
):
    """Answer a question in two LLM calls: schema linking, then SQL.

    A non-empty ``prior_exts`` is appended to the question as a remark. The
    first reply is parsed with ``schema_link_parse`` and the parsed links
    are fed into the SQL prompt for the second call.

    Returns:
        A dict with the (possibly extended) ``question``, ``model``,
        ``fields``, ``priorSchemaLinking``, ``dataDate``,
        ``schemaLinkingOutput``, ``schemaLinkStr`` and ``sqlOutput``.
    """
    for label, value in (
        ("question", question),
        ("model_name", model_name),
        ("fields_list", fields_list),
        ("data_date", data_date),
        ("prior_schema_links", prior_schema_links),
        ("prior_exts", prior_exts),
    ):
        logger.info(f"{label}: {value}")

    if prior_exts != '':
        question = question + ' 备注:' + prior_exts
    # NOTE(review): indentation was lost in the reviewed copy; this log is
    # assumed to sit outside the `if` above - confirm.
    logger.info(f"question_prior_exts: {question}")

    examples = self.get_examples_candidates(question, filter_condition, self.num_examples)

    # Pass 1: ask the LLM for schema links and parse them out of the reply.
    linking_prompt = self.generate_schema_linking_prompt(
        question, model_name, fields_list, prior_schema_links, examples
    )
    logger.debug(f"schema_linking_prompt->{linking_prompt}")
    schema_link_output = await self.llm._call_async(linking_prompt)
    schema_link_str = schema_link_parse(schema_link_output)

    # Pass 2: generate SQL from the parsed schema links.
    sql_prompt = self.generate_sql_prompt(
        question, model_name, schema_link_str, data_date, examples
    )
    logger.debug(f"sql_prompt->{sql_prompt}")
    sql_output = await self.llm._call_async(sql_prompt)

    resp = {
        'question': question,
        'model': model_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'dataDate': data_date,
        'schemaLinkingOutput': schema_link_output,
        'schemaLinkStr': schema_link_str,
        'sqlOutput': sql_output,
    }
    logger.info(f"resp: {resp}")

    return resp
|
||||
|
||||
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
                                   model_name: str, fields_list: List[str],
                                   data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
    """Generate schema links and SQL from a single combined LLM call.

    One prompt (built by ``self.generate_schema_linking_sql_prompt``) is sent
    to the LLM; the schema links and the SQL are then extracted from that one
    output with ``combo_schema_link_parse`` and ``combo_sql_parse``.

    Args:
        question: The user question; ``prior_exts`` is appended to it when non-empty.
        filter_condition: Forwarded unchanged to the example retrieval.
        model_name: Table/model name placed in the prompt.
        fields_list: Column names placed in the prompt.
        data_date: Date string placed in the prompt.
        prior_schema_links: Prior schema links placed in the prompt.
        prior_exts: Extra remarks; an empty string means "none".

    Returns:
        A dict echoing the inputs plus ``schemaLinkingComboOutput``,
        ``schemaLinkStr`` and ``sqlOutput``.
    """
    for message in (
        "question: {}".format(question),
        "model_name: {}".format(model_name),
        "fields_list: {}".format(fields_list),
        "data_date: {}".format(data_date),
        "prior_schema_links: {}".format(prior_schema_links),
        "prior_exts: {}".format(prior_exts),
    ):
        logger.info(message)

    # Extra remarks become part of the question for retrieval and the prompt.
    if not prior_exts == '':
        question = question + ' 备注:' + prior_exts
        logger.info("question_prior_exts: {}".format(question))

    candidates = self.get_examples_candidates(question, filter_condition, self.num_examples)
    combo_prompt = self.generate_schema_linking_sql_prompt(
        question, model_name, data_date, fields_list, prior_schema_links, candidates)
    logger.debug("schema_linking_sql_shortcut_prompt->{}".format(combo_prompt))
    combo_output = await self.llm._call_async(combo_prompt)

    # Both parsers read the same raw output.
    parsed_links = combo_schema_link_parse(combo_output)
    parsed_sql = combo_sql_parse(combo_output)

    resp = {
        'question': question,
        'model': model_name,
        'fields': fields_list,
        'priorSchemaLinking': prior_schema_links,
        'dataDate': data_date,
        'schemaLinkingComboOutput': combo_output,
        'schemaLinkStr': parsed_links,
        'sqlOutput': parsed_sql,
    }
    logger.info("resp: {}".format(resp))

    return resp
|
||||
|
||||
class SqlModeEnum(Enum):
    """Closed set of accepted ``sql_generation_mode`` strings.

    ``Text2DSLAgentWrapper.async_query2sql`` rejects any mode whose string is
    not the value of one of these members.
    """
    # Single combined prompt (dispatched to async_query2sql_shortcut).
    VALUE5 = '1_pass_auto_cot'
    # Single combined prompt, self-consistency variant (dispatched to tasks_run_shortcut).
    VALUE6 = '1_pass_auto_cot_self_consistency'
    # Schema linking then SQL, two prompts (dispatched to async_query2sql).
    VALUE7 = '2_pass_auto_cot'
    # Two prompts, self-consistency variant (dispatched to tasks_run).
    VALUE8 = '2_pass_auto_cot_self_consistency'
|
||||
|
||||
class Text2DSLAgentWrapper(object):
    """Facade that routes text-to-SQL requests to a ``Text2DSLAgentAutoCoT``.

    ``async_query2sql`` picks the agent entry point from ``sql_generation_mode``;
    every other method forwards to the same-named operation on the wrapped agent.
    """

    def __init__(self, sql_agent_act:Text2DSLAgentAutoCoT):
        """Store the wrapped auto-CoT agent."""
        self.sql_agent_act = sql_agent_act

    async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str):
        """Dispatch one query to the agent method selected by ``sql_generation_mode``.

        Args:
            question: The user question.
            filter_condition: Forwarded unchanged to the agent.
            model_name: Table/model name.
            fields_list: Column names.
            data_date: Date string used in the prompts.
            prior_schema_links: Prior schema links.
            prior_exts: Extra remarks.
            sql_generation_mode: One of the ``SqlModeEnum`` values.

        Returns:
            Whatever the selected agent method returns.

        Raises:
            ValueError: If ``sql_generation_mode`` is not a ``SqlModeEnum`` value.
        """
        if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
            raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")

        if sql_generation_mode == '1_pass_auto_cot':
            logger.info(f"sql wrapper: {sql_generation_mode}")
            # The agent method names this parameter ``data_date``; passing it as
            # ``current_date`` raised TypeError (unexpected keyword argument).
            resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp
        elif sql_generation_mode == '1_pass_auto_cot_self_consistency':
            logger.info(f"sql wrapper: {sql_generation_mode}")
            # NOTE(review): tasks_run_shortcut's signature is not visible here; the
            # ``current_date`` keyword is kept as-is — confirm it matches.
            resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp
        elif sql_generation_mode == '2_pass_auto_cot':
            logger.info(f"sql wrapper: {sql_generation_mode}")
            # Same keyword fix as the 1-pass branch: the parameter is ``data_date``.
            resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp
        elif sql_generation_mode == '2_pass_auto_cot_self_consistency':
            logger.info(f"sql wrapper: {sql_generation_mode}")
            # NOTE(review): tasks_run's signature is not visible here; the
            # ``current_date`` keyword is kept as-is — confirm it matches.
            resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp
        else:
            # Only reachable if SqlModeEnum gains a value without a branch above.
            raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')

    def update_configs(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]],
                       num_examples: int, num_fewshots: int, num_self_consistency: int):
        """Forward new examples and count settings to the agent's ``reload_setting``."""
        self.sql_agent_act.reload_setting(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units, num_examples=num_examples, num_fewshots=num_fewshots, num_self_consistency=num_self_consistency)

    def add_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
        """Forward to the agent's ``add_examples``."""
        self.sql_agent_act.add_examples(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units)

    def update_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
        """Forward to the agent's ``update_examples``."""
        self.sql_agent_act.update_examples(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units)

    def delete_examples(self, sql_example_ids:List[str]):
        """Forward to the agent's ``delete_examples``."""
        self.sql_agent_act.delete_examples(sql_example_ids=sql_example_ids)

    def get_examples(self, sql_example_ids: List[str]):
        """Return the result of the agent's ``get_examples`` for the given ids."""
        sql_agent_act_examples = self.sql_agent_act.get_examples(sql_example_ids=sql_example_ids)

        return sql_agent_act_examples

    def count_examples(self):
        """Return the result of the agent's ``count_examples``."""
        sql_agent_examples_act_cnt = self.sql_agent_act.count_examples()

        return sql_agent_examples_act_cnt
|
||||
@@ -1,61 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Mapping
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from config.config_parse import (TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
LLMPARSER_HOST, LLMPARSER_PORT, TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
from few_shot_example.sql_examplar import examplars as sql_examplars
|
||||
|
||||
|
||||
def text2sql_agent_setting_update(llm_host:str, llm_port:str,
                                  sql_examplars:List[Mapping[str, str]], example_nums:int):
    """POST the examples and example count to ``/text2sql_agent_setting_update/``.

    The JSON body carries ``sqlExamplars`` and ``exampleNums``; the response
    text is written to the log. Nothing is returned.
    """
    endpoint = f"http://{llm_host}:{llm_port}/text2sql_agent_setting_update/"
    body = json.dumps({"sqlExamplars":sql_examplars, "exampleNums":example_nums})
    reply = requests.post(endpoint, data=body, headers={'content-type': 'application/json'})
    logger.info(reply.text)
|
||||
|
||||
|
||||
def text2dsl_agent_cs_setting_update(llm_host:str, llm_port:str,
                                     sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int):
    """POST the examples and the three count settings to the self-consistency agent endpoint.

    The JSON body carries ``sqlExamplars``, ``exampleNums``, ``fewshotNums`` and
    ``selfConsistencyNums``; the response text is written to the log. Nothing is returned.
    """
    # NOTE(review): the path segment "texg2sqt" looks like a misspelling of
    # "text2sql" — confirm against the server's route before changing it.
    endpoint = f"http://{llm_host}:{llm_port}/texg2sqt_cs_agent_setting_update/"
    body = json.dumps({"sqlExamplars":sql_examplars,
                       "exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums})
    reply = requests.post(endpoint, data=body, headers={'content-type': 'application/json'})
    logger.info(reply.text)
|
||||
|
||||
|
||||
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
                                          is_shortcut:bool, is_self_consistency:bool,
                                          sql_examplars:List[Mapping[str, str]], example_nums:int, fewshot_nums:int, self_consistency_nums:int):
    """POST the wrapper flags, examples and count settings to ``/query2sql_setting_update/``.

    Example ids are generated here as the string form of each example's
    position ("0", "1", ...). The response text is written to the log.
    Nothing is returned.
    """
    generated_ids = [str(position) for position in range(0, len(sql_examplars))]

    endpoint = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
    body = json.dumps({"isShortcut":is_shortcut, "isSelfConsistency":is_self_consistency,
                       "sqlExamplars":sql_examplars, "sqlIds": generated_ids,
                       "exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums})
    reply = requests.post(endpoint, data=body, headers={'content-type': 'application/json'})
    logger.info(reply.text)
|
||||
|
||||
if __name__ == "__main__":
    # Push the configured flags, bundled examples and count settings to the
    # running service when this file is executed as a script.
    text2dsl_agent_wrapper_setting_update(
        llm_host=LLMPARSER_HOST,
        llm_port=LLMPARSER_PORT,
        is_shortcut=TEXT2DSL_IS_SHORTCUT,
        is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY,
        sql_examplars=sql_examplars,
        example_nums=TEXT2DSL_EXAMPLE_NUM,
        fewshot_nums=TEXT2DSL_FEWSHOTS_NUM,
        self_consistency_nums=TEXT2DSL_SELF_CONSISTENCY_NUM,
    )
|
||||
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
import asyncio
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from sql.constructor import FewShotPromptTemplate2
|
||||
from sql.sql_agent import Text2DSLAgent, Text2DSLAgentConsistency, Text2DSLAgentWrapper
|
||||
|
||||
from instances.llm_instance import llm
|
||||
from instances.text2vec import Text2VecEmbeddingFunction
|
||||
from instances.chromadb_instance import client
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from few_shot_example.sql_examplar import examplars as sql_examplars
|
||||
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTCS_COLLECTION_NAME,
|
||||
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
|
||||
TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
|
||||
|
||||
# Module-level wiring: builds the two agents and the router at import time.

# One embedding function instance is shared by both collections.
emb_func = Text2VecEmbeddingFunction()
# Each agent gets its own collection; both are created (or fetched) with
# "hnsw:space" set to "cosine".
text2dsl_agent_collection = client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME,
                                            embedding_function=emb_func,
                                            metadata={"hnsw:space": "cosine"})
text2dsl_agentcs_collection = client.get_or_create_collection(name=TEXT2DSLAGENTCS_COLLECTION_NAME,
                                            embedding_function=emb_func,
                                            metadata={"hnsw:space": "cosine"})

# Example prompters retrieve on the "question" key and join examples with a blank line.
text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection,
                                            retrieval_key="question",
                                            few_shot_seperator='\n\n')

text2dsl_agentcs_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agentcs_collection,
                                            retrieval_key="question",
                                            few_shot_seperator='\n\n')

# Note: the plain agent receives TEXT2DSL_EXAMPLE_NUM as its num_fewshots.
text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_EXAMPLE_NUM,
                            sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)

text2sql_cs_agent = Text2DSLAgentConsistency(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
                            sql_example_prompter=text2dsl_agentcs_example_prompter, llm=llm)

# Example ids are the string form of each bundled example's position.
sql_ids = [str(i) for i in range(0, len(sql_examplars))]
# Load the bundled examples into both agents.
text2sql_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM)
text2sql_cs_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)

# The router chooses between the two agents based on the configured flags.
text2sql_agent_router = Text2DSLAgentWrapper(sql_agent=text2sql_agent, sql_agent_cs=text2sql_cs_agent,
                            is_shortcut=TEXT2DSL_IS_SHORTCUT, is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY)
|
||||
@@ -1,405 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Union, Mapping, Any
|
||||
from collections import Counter
|
||||
import random
|
||||
import asyncio
|
||||
from langchain.llms.base import BaseLLM
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from instances.logging_instance import logger
|
||||
|
||||
from sql.constructor import FewShotPromptTemplate2
|
||||
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
|
||||
|
||||
|
||||
class Text2DSLAgent(object):
    """Few-shot text-to-SQL agent built on an example prompter and an LLM.

    The prompter stores and retrieves examples and renders them into prompt
    text; the LLM is called asynchronously through ``_call_async``.
    """

    def __init__(self, num_fewshots:int,
                 sql_example_prompter:FewShotPromptTemplate2,
                 llm: BaseLLM):
        """Store the few-shot count, the example prompter and the LLM."""
        self.llm = llm
        self.sql_example_prompter = sql_example_prompter
        self.num_fewshots = num_fewshots

    def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]], num_fewshots: int):
        """Set a new few-shot count, then reload the prompter's examples."""
        self.num_fewshots = num_fewshots
        self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)

    def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]]):
        """Forward to the prompter's ``add_few_shot_example``."""
        self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)

    def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]]):
        """Forward to the prompter's ``update_few_shot_example``."""
        self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)

    def delete_examples(self, sql_example_ids: List[str]):
        """Forward to the prompter's ``delete_few_shot_example``."""
        self.sql_example_prompter.delete_few_shot_example(sql_example_ids)

    def count_examples(self):
        """Return the prompter's ``count_few_shot_example`` result."""
        return self.sql_example_prompter.count_few_shot_example()

    def get_fewshot_examples(self, query_text: str, filter_condition: Mapping[str,str])->List[Mapping[str, str]]:
        """Retrieve ``self.num_fewshots`` examples for ``query_text`` from the prompter."""
        return self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_fewshots, filter_condition)

    def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
                                       prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
        """Build the schema-linking prompt: instruction, rendered examples, new case."""
        # Prior links are rendered as ['key'->value,...].
        rendered_links = []
        for link_key, link_value in prior_schema_links.items():
            rendered_links.append("'{}'->{}".format(link_key, link_value))
        prior_links_text = '[' + ','.join(rendered_links) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

        examples_text = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template="Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}",
            example_keys=["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"],
            few_shot_example_meta_list=fewshot_example_list)

        # The new case ends with the step-by-step cue for the model to continue.
        case_text = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。".format(
            tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_links_text, question=user_query)

        return '\n\n'.join((instruction, examples_text, case_text))

    def generate_sql_prompt(self, user_query: str, domain_name: str,
                            schema_link_str: str, data_date: str,
                            fewshot_example_list:List[Mapping[str, str]])-> str:
        """Build the SQL-generation prompt: instruction, rendered examples, new case."""
        instruction = "# 根据schema_links为每个问题生成SQL查询语句"

        examples_text = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template="问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}",
            example_keys=["question", "currentDate", "tableName", "schemaLinks", "sql"],
            few_shot_example_meta_list=fewshot_example_list)

        # The new case ends with an open "SQL:" for the model to complete.
        case_text = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:".format(
            question=user_query, currentDate=data_date, tableName=domain_name, schemaLinks=schema_link_str)

        return '\n\n'.join((instruction, examples_text, case_text))

    def generate_schema_linking_sql_prompt(self, user_query: str,
                                           domain_name: str,
                                           data_date : str,
                                           fields_list: List[str],
                                           prior_schema_links: Mapping[str,str],
                                           fewshot_example_list:List[Mapping[str, str]]):
        """Build the combined prompt that asks for schema links and SQL in one output."""
        rendered_links = []
        for link_key, link_value in prior_schema_links.items():
            rendered_links.append("'{}'->{}".format(link_key, link_value))
        prior_links_text = '[' + ','.join(rendered_links) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"

        examples_text = self.sql_example_prompter.make_few_shot_example_prompt(
            few_shot_template="Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}\nSQL:{sql}",
            example_keys=["tableName", "fieldsList", "priorSchemaLinks", "currentDate", "question", "analysis", "schemaLinks", "sql"],
            few_shot_example_meta_list=fewshot_example_list)

        case_text = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析: 让我们一步一步地思考。".format(
            tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_links_text, currentDate=data_date, question=user_query)

        return '\n\n'.join((instruction, examples_text, case_text))

    async def async_query2sql(self, query_text: str, filter_condition: Mapping[str,str],
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
        """Generate SQL in two LLM passes: schema linking first, then SQL.

        Returns a dict echoing the inputs plus ``schemaLinkingOutput``,
        ``schemaLinkStr`` and ``sqlOutput``.
        """
        for message in (
            "query_text: {}".format(query_text),
            "model_name: {}".format(model_name),
            "fields_list: {}".format(fields_list),
            "data_date: {}".format(data_date),
            "prior_schema_links: {}".format(prior_schema_links),
            "prior_exts: {}".format(prior_exts),
        ):
            logger.info(message)

        # Extra remarks become part of the query for retrieval and both prompts.
        if not prior_exts == '':
            query_text = query_text + ' 备注:' + prior_exts
            logger.info("query_text_prior_exts: {}".format(query_text))

        examples = self.get_fewshot_examples(query_text, filter_condition)

        # Pass 1: schema linking.
        linking_prompt = self.generate_schema_linking_prompt(query_text, model_name, fields_list, prior_schema_links, examples)
        logger.debug("schema_linking_prompt->{}".format(linking_prompt))
        linking_raw = await self.llm._call_async(linking_prompt)
        linking_parsed = schema_link_parse(linking_raw)

        # Pass 2: SQL generation conditioned on the parsed schema links.
        generation_prompt = self.generate_sql_prompt(query_text, model_name, linking_parsed, data_date, examples)
        logger.debug("sql_prompt->{}".format(generation_prompt))
        generated_sql = await self.llm._call_async(generation_prompt)

        resp = {
            'query': query_text,
            'model': model_name,
            'fields': fields_list,
            'priorSchemaLinking': prior_schema_links,
            'dataDate': data_date,
            'schemaLinkingOutput': linking_raw,
            'schemaLinkStr': linking_parsed,
            'sqlOutput': generated_sql,
        }
        logger.info("resp: {}".format(resp))

        return resp

    async def async_query2sql_shortcut(self, query_text: str, filter_condition: Mapping[str,str],
                                       model_name: str, fields_list: List[str],
                                       data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
        """Generate schema links and SQL from a single combined LLM call.

        Returns a dict echoing the inputs plus ``schemaLinkingComboOutput``,
        ``schemaLinkStr`` and ``sqlOutput``.
        """
        for message in (
            "query_text: {}".format(query_text),
            "model_name: {}".format(model_name),
            "fields_list: {}".format(fields_list),
            "data_date: {}".format(data_date),
            "prior_schema_links: {}".format(prior_schema_links),
            "prior_exts: {}".format(prior_exts),
        ):
            logger.info(message)

        if not prior_exts == '':
            query_text = query_text + ' 备注:' + prior_exts
            logger.info("query_text_prior_exts: {}".format(query_text))

        examples = self.get_fewshot_examples(query_text, filter_condition)
        combo_prompt = self.generate_schema_linking_sql_prompt(query_text, model_name, data_date, fields_list, prior_schema_links, examples)
        logger.debug("schema_linking_sql_shortcut_prompt->{}".format(combo_prompt))
        combo_output = await self.llm._call_async(combo_prompt)

        # Both parsers read the same raw output.
        parsed_links = combo_schema_link_parse(combo_output)
        parsed_sql = combo_sql_parse(combo_output)

        resp = {
            'query': query_text,
            'model': model_name,
            'fields': fields_list,
            'priorSchemaLinking': prior_schema_links,
            'dataDate': data_date,
            'schemaLinkingComboOutput': combo_output,
            'schemaLinkStr': parsed_links,
            'sqlOutput': parsed_sql,
        }
        logger.info("resp: {}".format(resp))

        return resp
|
||||
|
||||
class Text2DSLAgentConsistency(object):
    """Self-consistency text-to-SQL agent.

    Retrieves ``num_examples`` candidate examples, draws ``num_self_consistency``
    few-shot subsets of size ``num_fewshots`` from them, runs the two-pass
    (schema linking, then SQL) flow once per subset concurrently, and picks the
    most frequent schema-link string and SQL output by majority vote.
    """

    def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
                 sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM) -> None:
        """Store the settings, the example prompter and the LLM.

        Raises:
            AssertionError: If ``num_fewshots`` exceeds ``num_examples``.
        """
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency

        self.llm = llm
        self.sql_example_prompter = sql_example_prompter

    def reload_setting(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
        """Replace the three count settings, then reload the prompter's examples.

        Raises:
            AssertionError: If ``num_fewshots`` exceeds ``num_examples`` or
                ``num_self_consistency`` is below 1.
        """
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency
        assert self.num_self_consistency >= 1
        self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)

    def add_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
        """Forward to the prompter's ``add_few_shot_example``."""
        self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)

    def update_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
        """Forward to the prompter's ``update_few_shot_example``."""
        self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)

    def delete_examples(self, sql_example_ids:List[str]):
        """Forward to the prompter's ``delete_few_shot_example``."""
        self.sql_example_prompter.delete_few_shot_example(sql_example_ids)

    def count_examples(self):
        """Return the prompter's ``count_few_shot_example`` result."""
        return self.sql_example_prompter.count_few_shot_example()

    def get_examples_candidates(self, query_text: str, filter_condition: Mapping[str, str])->List[Mapping[str, str]]:
        """Retrieve ``self.num_examples`` candidate examples for ``query_text``."""
        few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_examples, filter_condition)

        return few_shot_example_meta_list

    def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]])-> List[List[Mapping[str, str]]]:
        """Draw ``num_self_consistency`` random subsets of up to ``num_fewshots`` examples.

        Each subset is sampled without replacement. The caller's list is left
        untouched (the previous implementation shuffled it in place).
        """
        # Cap the subset size so fewer retrieved candidates than requested is fine.
        subset_size = min(self.num_fewshots, len(example_meta_list))

        return [random.sample(example_meta_list, subset_size)
                for _ in range(self.num_self_consistency)]

    def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
                                       prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
        """Build the schema-linking prompt: instruction, rendered examples, new case."""
        # Prior links are rendered as ['key'->value,...].
        prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

        schema_linking_example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"]
        schema_linking_example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"
        schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
                                                                                    example_keys=schema_linking_example_keys,
                                                                                    few_shot_example_meta_list=fewshot_example_list)

        # The new case ends with the step-by-step cue for the model to continue.
        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_schema_links_str, question=user_query)

        schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
        return schema_linking_prompt

    def generate_schema_linking_prompt_pool(self, user_query: str, domain_name: str, fields_list: List[str],
                                            prior_schema_links: Mapping[str,str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
        """Build one schema-linking prompt per few-shot subset, in pool order."""
        return [
            self.generate_schema_linking_prompt(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list)
            for fewshot_example_list in fewshot_example_list_pool
        ]

    def generate_sql_prompt(self, user_query: str, domain_name: str,
                            schema_link_str: str, data_date: str,
                            fewshot_example_list:List[Mapping[str, str]])-> str:
        """Build the SQL-generation prompt: instruction, rendered examples, new case."""
        instruction = "# 根据schema_links为每个问题生成SQL查询语句"
        sql_example_keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
        sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"

        sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
                                                                                    example_keys=sql_example_keys,
                                                                                    few_shot_example_meta_list=fewshot_example_list)

        # The new case ends with an open "SQL:" for the model to complete.
        new_case_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
        new_case_prompt = new_case_template.format(question=user_query, currentDate=data_date, tableName=domain_name, schemaLinks=schema_link_str)

        sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt

        return sql_example_prompt

    def generate_sql_prompt_pool(self, user_query: str, domain_name: str, data_date: str,
                                 schema_link_str_pool: List[str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
        """Build one SQL prompt per (schema-link string, few-shot subset) pair.

        The two pools are paired positionally with ``zip``.
        """
        return [
            self.generate_sql_prompt(user_query, domain_name, schema_link_str, data_date, fewshot_example_list)
            for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool)
        ]

    def self_consistency_vote(self, output_res_pool:List[str]):
        """Return the most frequent output and each distinct output's vote share.

        Returns:
            ``(winner, shares)`` where ``shares`` maps each distinct output to
            its count divided by the pool size.

        Raises:
            IndexError: If ``output_res_pool`` is empty.
        """
        output_res_counts = Counter(output_res_pool)
        output_res_max = output_res_counts.most_common(1)[0][0]
        total_output_num = len(output_res_pool)

        vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()}

        return output_res_max, vote_percentage

    def schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]:
        """Normalize each ``[a, b]``-style string by stripping and sorting its items.

        Making item order and spacing uniform lets equal link sets compare
        equal as strings when they are counted in the vote.
        """
        schema_linking_list_unify = []
        for schema_linking_str in schema_linking_list:
            schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')]))
            schema_linking_str_unify = f'[{schema_linking_str_unify}]'
            schema_linking_list_unify.append(schema_linking_str_unify)

        return schema_linking_list_unify

    async def generate_schema_linking_tasks(self, user_query: str, domain_name: str,
                                            fields_list: List[str], prior_schema_links: Mapping[str,str],
                                            fewshot_example_list_combo:List[List[Mapping[str, str]]]):
        """Run the schema-linking LLM calls for all few-shot subsets concurrently.

        Returns the raw LLM outputs in the same order as the subsets.
        """
        schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(user_query, domain_name,
                                                                            fields_list, prior_schema_links,
                                                                            fewshot_example_list_combo)
        schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
        schema_linking_output_res_pool = await asyncio.gather(*schema_linking_output_task_pool)
        logger.debug(f'schema_linking_output_res_pool:{schema_linking_output_res_pool}')

        return schema_linking_output_res_pool

    async def generate_sql_tasks(self, user_query: str, domain_name: str, data_date: str,
                                 schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
        """Run the SQL-generation LLM calls for all few-shot subsets concurrently.

        Returns the raw LLM outputs in the same order as the subsets.
        """
        # generate_sql_prompt_pool expects (…, data_date, schema_link_str_pool, …).
        # The previous call passed these two swapped, so the date string was
        # zipped character by character as if it were the schema-link pool.
        sql_prompt_pool = self.generate_sql_prompt_pool(user_query, domain_name, data_date, schema_link_str_pool, fewshot_example_list_combo)
        sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
        sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
        logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')

        return sql_output_res_pool

    async def tasks_run(self, user_query: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
        """Run the full self-consistency flow for one query.

        Args:
            user_query: The user question; ``prior_exts`` is appended when non-empty.
            filter_condition: Forwarded unchanged to the example retrieval.
            domain_name: Table/model name placed in the prompts.
            fields_list: Column names placed in the schema-linking prompts.
            prior_schema_links: Prior schema links placed in the schema-linking prompts.
            data_date: Date string placed in the SQL prompts.
            prior_exts: Extra remarks; an empty string means "none".

        Returns:
            A dict echoing the inputs plus the winning ``schemaLinkStr`` and
            ``sqlOutput`` and their vote shares (``schemaLinkingWeight``,
            ``sqlWeight``).
        """
        logger.info("user_query: {}".format(user_query))
        logger.info("domain_name: {}".format(domain_name))
        logger.info("fields_list: {}".format(fields_list))
        logger.info("current_date: {}".format(data_date))
        logger.info("prior_schema_links: {}".format(prior_schema_links))
        logger.info("prior_exts: {}".format(prior_exts))

        if prior_exts != '':
            user_query = user_query + ' 备注:'+prior_exts
            logger.info("user_query_prior_exts: {}".format(user_query))

        fewshot_example_meta_list = self.get_examples_candidates(user_query, filter_condition)
        fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list)

        schema_linking_output_candidates = await self.generate_schema_linking_tasks(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
        schema_linking_candidate_list = [schema_link_parse(schema_linking_output_candidate) for schema_linking_output_candidate in schema_linking_output_candidates]
        logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
        schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
        logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')

        # The vote runs on the normalized strings; the SQL prompts below use the
        # un-normalized parsed strings, one per few-shot subset.
        schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)

        sql_output_candicates = await self.generate_sql_tasks(user_query, domain_name, data_date, schema_linking_candidate_list,fewshot_example_list_combo)
        logger.debug(f'sql_output_candicates:{sql_output_candicates}')
        sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)

        resp = dict()
        resp['query'] = user_query
        resp['model'] = domain_name
        resp['fields'] = fields_list
        resp['priorSchemaLinking'] = prior_schema_links
        resp['dataDate'] = data_date

        resp['schemaLinkStr'] = schema_linking_output_max
        resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage

        resp['sqlOutput'] = sql_output_max
        resp['sqlWeight'] = sql_output_vote_percentage

        logger.info("resp: {}".format(resp))

        return resp
|
||||
|
||||
class Text2DSLAgentWrapper(object):
    """Route a text-to-SQL request to one of the wrapped agents.

    Two agents are held: ``sql_agent`` (exposing ``async_query2sql`` and
    ``async_query2sql_shortcut``) and ``sql_agent_cs`` (exposing
    ``tasks_run``). Two boolean flags select which entry point handles a
    request; ``is_self_consistency`` takes precedence over ``is_shortcut``.
    """

    def __init__(self, sql_agent:Text2DSLAgent, sql_agent_cs:Text2DSLAgentConsistency,
                is_shortcut:bool, is_self_consistency:bool):
        """Store the two agents and the initial routing flags."""
        self.sql_agent = sql_agent
        self.sql_agent_cs = sql_agent_cs

        self.is_shortcut, self.is_self_consistency = is_shortcut, is_self_consistency

    async def async_query2sql(self, query_text: str, filter_condition: Mapping[str,str],
                    model_name: str, fields_list: List[str],
                    data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
        """Await the agent selected by the current flags and return its result.

        Routing order:
          1. ``is_self_consistency`` -> ``sql_agent_cs.tasks_run``
          2. ``is_shortcut``         -> ``sql_agent.async_query2sql_shortcut``
          3. otherwise               -> ``sql_agent.async_query2sql``

        The return value is whatever the selected coroutine returns; this
        wrapper does not post-process it.
        """
        if self.is_self_consistency:
            logger.info("sql wrapper: self_consistency")
            # tasks_run names its parameters differently from the other two
            # entry points (user_query / domain_name), hence the explicit mapping.
            return await self.sql_agent_cs.tasks_run(
                user_query=query_text,
                filter_condition=filter_condition,
                domain_name=model_name,
                fields_list=fields_list,
                prior_schema_links=prior_schema_links,
                data_date=data_date,
                prior_exts=prior_exts,
            )

        # Both remaining entry points take exactly the same keyword arguments.
        agent_kwargs = dict(
            query_text=query_text,
            filter_condition=filter_condition,
            model_name=model_name,
            fields_list=fields_list,
            data_date=data_date,
            prior_schema_links=prior_schema_links,
            prior_exts=prior_exts,
        )

        if self.is_shortcut:
            logger.info("sql wrapper: shortcut")
            return await self.sql_agent.async_query2sql_shortcut(**agent_kwargs)

        logger.info("sql wrapper: normal")
        return await self.sql_agent.async_query2sql(**agent_kwargs)

    def update_configs(self, is_shortcut, is_self_consistency,
                    sql_examplars, num_examples, num_fewshots, num_self_consistency):
        """Replace the routing flags and push new example settings to both agents.

        The flags are updated first, then ``update_examples`` is called on
        ``sql_agent`` and afterwards on ``sql_agent_cs``.
        """
        self.is_shortcut, self.is_self_consistency = is_shortcut, is_self_consistency

        # NOTE(review): the single-pass agent is given ``num_examples`` as its
        # ``num_fewshots`` value -- confirm this is intended.
        self.sql_agent.update_examples(
            sql_examplars=sql_examplars,
            num_fewshots=num_examples,
        )
        self.sql_agent_cs.update_examples(
            sql_examplars=sql_examplars,
            num_examples=num_examples,
            num_fewshots=num_fewshots,
            num_self_consistency=num_self_consistency,
        )
|
||||
|
||||
Reference in New Issue
Block a user