(improvement)(Chat) Move python module from Chat To Headless (#823)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-15 12:47:11 +08:00
committed by GitHub
parent 988a025cdf
commit 36136e4c15
30 changed files with 2 additions and 2 deletions

View File

@@ -0,0 +1,99 @@
# -*- coding:utf-8 -*-
import json
import os
import re
import sys
from typing import Any, List, Mapping, Union
from instances.logging_instance import logger
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
def construct_plugin_prompt(tool_config):
tool_name = tool_config["name"]
tool_description = tool_config["description"]
tool_examples = tool_config["examples"]
prompt = "【工具名称】\n" + tool_name + "\n"
prompt += "【工具描述】\n" + tool_description + "\n"
prompt += "【工具适用问题示例】\n"
for example in tool_examples:
prompt += example + "\n"
return prompt
def construct_plugin_pool_prompt(tool_config_list):
tool_explain_list = []
for tool_config in tool_config_list:
tool_explain = construct_plugin_prompt(tool_config)
tool_explain_list.append(tool_explain)
tool_explain_list_str = "\n\n".join(tool_explain_list)
return tool_explain_list_str
def construct_task_prompt(query_text, tool_explain_list_str):
instruction = """问题为:{query_text}\n请根据问题和工具的描述选择对应的工具完成任务。请注意只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)并给出最终选择输出格式为json,key为分析过程, ’选择工具‘""".format(
query_text=query_text
)
prompt = "工具选择如下:\n\n{tool_explain_list_str}\n\n【任务说明】\n{instruction}".format(
instruction=instruction, tool_explain_list_str=tool_explain_list_str
)
return prompt
def plugin_selection_output_parse(llm_output: str) -> Union[Mapping[str, str], None]:
try:
pattern = r"\{[^{}]+\}"
find_result = re.findall(pattern, llm_output)
result = find_result[0].strip()
logger.info("result: {}", result)
result_dict = json.loads(result)
logger.info("result_dict: {}", result_dict)
key_mapping = {"分析过程": "analysis", "选择工具": "toolSelection"}
converted_result_dict = {
key_mapping[key]: value
for key, value in result_dict.items()
if key in key_mapping
}
except Exception as e:
logger.exception(e)
converted_result_dict = None
return converted_result_dict
def plugins_config_format_convert(
plugin_config_list: List[Mapping[str, Any]]
) -> List[Mapping[str, Any]]:
plugin_config_list_new = []
for plugin_config in plugin_config_list:
plugin_config_new = dict()
name = plugin_config["name"]
description = plugin_config["description"]
examples = plugin_config["examples"]
parameters = plugin_config["parameters"]
examples_str = "\n".join(examples)
description_new = """{plugin_desc}\n\n例如能够处理如下问题:\n{examples_str}""".format(
plugin_desc=description, examples_str=examples_str
)
plugin_config_new["name"] = name
plugin_config_new["description"] = description_new
plugin_config_new["parameters"] = parameters
plugin_config_list_new.append(plugin_config_new)
return plugin_config_list_new

View File

@@ -0,0 +1,28 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from plugin_call.prompt_construct import (
construct_plugin_pool_prompt,
construct_task_prompt,
plugin_selection_output_parse,
)
from instances.llm_instance import llm
def plugin_selection_run(
query_text: str, plugin_configs: List[Mapping[str, Any]]
) -> Union[Mapping[str, str], None]:
tools_prompt = construct_plugin_pool_prompt(plugin_configs)
task_prompt = construct_task_prompt(query_text, tools_prompt)
llm_output = llm(task_prompt)
parsed_output = plugin_selection_output_parse(llm_output)
return parsed_output

View File

@@ -0,0 +1,98 @@
# -*- coding:utf-8 -*-
import os
import sys
import uuid
from typing import Any, List, Mapping, Optional, Union
import chromadb
from chromadb import Client
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from chromadb.api.types import CollectionMetadata
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection,
parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
get_chroma_collection_by_ids, get_chroma_collection_size,
add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
empty_chroma_collection_2)
from utils.text2vec import Text2VecEmbeddingFunction
class ChromaCollectionRetriever(object):
def __init__(self, collection:Collection):
self.collection = collection
def retrieval_query_run(self, query_texts_list:List[str]=None, query_embeddings:Embeddings=None,
filter_condition:Mapping[str,str]=None, n_results:int=5):
retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings,
filter_condition, n_results)
parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res)
parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, query_embeddings, parsed_retrieval_res)
logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)
return parsed_retrieval_res_format
def get_query_by_ids(self, query_ids:List[str]):
queries = get_chroma_collection_by_ids(self.collection, query_ids)
return queries
def get_query_size(self):
return get_chroma_collection_size(self.collection)
def add_queries(self, query_text_list:List[str],
query_id_list:List[str],
metadatas:List[Mapping[str, str]]=None,
embeddings:Embeddings=None):
add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
return True
def update_queries(self, query_text_list:List[str],
query_id_list:List[str],
metadatas:List[Mapping[str, str]]=None,
embeddings:Embeddings=None):
update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
return True
def delete_queries_by_ids(self, query_ids:List[str]):
delete_chroma_collection_by_ids(self.collection, query_ids)
return True
def empty_query_collection(self):
self.collection = empty_chroma_collection_2(self.collection)
return True
class CollectionManager(object):
def __init__(self, chroma_client:Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None):
self.chroma_client = chroma_client
self.embedding_func = embedding_func
self.collection_meta = collection_meta
def list_collections(self):
collection_list = self.chroma_client.list_collections()
return collection_list
def get_collection(self, collection_name:str):
collection = self.chroma_client.get_collection(name=collection_name, embedding_function=self.embedding_func)
return collection
def create_collection(self, collection_name:str):
collection = self.chroma_client.create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
return collection
def get_or_create_collection(self, collection_name:str):
collection = self.chroma_client.get_or_create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
return collection
def delete_collection(self, collection_name:str):
self.chroma_client.delete_collection(collection_name)
return True

View File

@@ -0,0 +1,37 @@
# -*- coding:utf-8 -*-
import os
import sys
import uuid
from typing import Any, List, Mapping, Optional, Union
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from utils.text2vec import Text2VecEmbeddingFunction
from instances.chromadb_instance import client
from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
from retriever import ChromaCollectionRetriever, CollectionManager
emb_func = Text2VecEmbeddingFunction()
collection_manager = CollectionManager(chroma_client=client, embedding_func=emb_func
,collection_meta={"hnsw:space": "cosine"})
solved_query_collection = collection_manager.get_or_create_collection(collection_name=SOLVED_QUERY_COLLECTION_NAME)
preset_query_collection = collection_manager.get_or_create_collection(collection_name=PRESET_QUERY_COLLECTION_NAME)
solved_query_retriever = ChromaCollectionRetriever(solved_query_collection)
preset_query_retriever = ChromaCollectionRetriever(preset_query_collection)
logger.info("init_solved_query_collection_size: {}".format(solved_query_retriever.get_query_size()))
logger.info("init_preset_query_collection_size: {}".format(preset_query_retriever.get_query_size()))

View File

@@ -0,0 +1,167 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union, Tuple
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from instances.text2vec_instance import emb_func
from sqlglot import parse_one, exp
import numpy as np
def sql2schema_linking(sql: str):
sql_ast = parse_one(sql)
fields_raw = []
table_alias_map = dict()
literals = []
fields = []
for literal in sql_ast.find_all(exp.Literal):
literals.append(literal.output_name)
for column in sql_ast.find_all(exp.Column):
fields_raw.append({
'column_table_alias': column.table,
'column_name': column.name,
})
for table in sql_ast.find_all(exp.Table):
if table.alias not in table_alias_map:
table_alias_map[table.alias] = table.name
logger.debug(f'literals: {literals}')
logger.debug(f'fields_raw: {fields_raw}')
logger.debug(f'table_alias_map: {table_alias_map}')
for field in fields_raw:
column_table_alias = field['column_table_alias']
column_name = field['column_name']
if column_table_alias.strip() == '':
column_table = ''
fields.append((column_table, column_name))
elif column_table_alias in table_alias_map:
column_table = table_alias_map[column_table_alias]
fields.append((column_table, column_name))
elif column_table_alias in table_alias_map.values():
column_table = column_table_alias
fields.append((column_table, column_name))
else:
logger.error(f'column_table_alias: {column_table_alias} not in table_alias_map: {table_alias_map}')
raise Exception(f'column_table_alias: {column_table_alias} not in table_alias_map: {table_alias_map}')
return {
'fields': list(set(fields)),
'literals': literals
}
def get_question_slices(question: str, min_window_size: int, max_window_size: int):
assert min_window_size <= max_window_size
assert min_window_size > 1
assert max_window_size < len(question)+1
question_slices = []
for i in range(len(question)):
for j in range(i+1, len(question)+1):
if j-i >= min_window_size and j-i <= max_window_size:
question_slices.append(question[i:j])
return question_slices
def schema_linking_match(fields: List[Tuple[str,str]], question: str, min_window_size: int, max_window_size: int):
question_slices = get_question_slices(question, min_window_size, max_window_size)
assert len(question_slices) > 0
logger.debug('question_slices_len:{}'.format(len(question_slices)))
logger.debug(f'question_slices: {question_slices}')
question_slices_embeddings = emb_func(question_slices)
fields_embeddings = emb_func([field[1] for field in fields])
fields_embeddings = np.array(fields_embeddings) # (n_fields, 768)
question_slices_embeddings = np.array(question_slices_embeddings) # (n_question_slices, 768)
question_slices_embeddings_norm = question_slices_embeddings / np.linalg.norm(question_slices_embeddings, axis=1, keepdims=True) # (n_question_slices, 768)
question_slices_embeddings_norm_transpose = question_slices_embeddings_norm.T # (768, n_question_slices)
if len(fields) > 0:
fields_embeddings_norm = fields_embeddings / np.linalg.norm(fields_embeddings, axis=1, keepdims=True) # (n_fields, 768)
fields_question_slices_similarity = np.matmul(fields_embeddings_norm, question_slices_embeddings_norm_transpose) # (n_fields, n_question_slices)
logger.debug('fields_question_slices_similarity_max:{}'.format(np.max(fields_question_slices_similarity, axis=1)))
fields_question_slices_argmax = np.argmax(fields_question_slices_similarity, axis=1) # (n_fields, )
logger.debug('fields_question_slices_argmax:{}'.format(fields_question_slices_argmax))
fields_question_slices_pair = []
for i in range(len(fields)):
if fields[i][0]!="":
fields_question_slices_pair.append((fields[i][0]+'.'+fields[i][1], question_slices[fields_question_slices_argmax[i]]))
else:
fields_question_slices_pair.append((fields[i][1], question_slices[fields_question_slices_argmax[i]]))
logger.debug(f'fields_question_slices_pair: {fields_question_slices_pair}')
else:
fields_question_slices_pair = []
return fields_question_slices_pair
def construct_schema_linking_cot(question:str, fields_question_slices_pair:List[Tuple[str,str]], literals_list:List[str]):
cot_intro= """Lets think step by step. In the question "{question}", we are asked:""".format(question=question)
schema_linkings_list = []
fields_cot_template = """"{question_slice}" so we need column = [{field}]"""
fields_cot_list = []
for field, question_slice in fields_question_slices_pair:
fields_cot_list.append(fields_cot_template.format(question_slice=question_slice, field=field))
schema_linkings_list.append(field)
fields_cot = '\n'.join(fields_cot_list)
literals_cot_template = """Based on the tables, columns, and Foreign_keys, The set of possible cell values are = [{literals}]. So the Schema_links are:"""
literals_cot = literals_cot_template.format(literals=",".join(literals_list))
schema_linkings_list += literals_list
schema_linking_str = '[' + ",".join(schema_linkings_list) + ']'
schema_linkings = 'Schema_links: '+ schema_linking_str
cot = """{cot_intro}""".format(cot_intro=cot_intro)
if len(fields_cot_list) > 0:
cot += '\n' + fields_cot
cot += '\n' + literals_cot
cot += '\n' + schema_linkings
return cot, schema_linking_str
def auto_cot_run(question, sql, min_window_size, max_window_size):
sql_entity = sql2schema_linking(sql)
logger.debug(f'sql_entity: {sql_entity}')
fields = sql_entity['fields']
literals = sql_entity['literals']
field_linked_pairs = schema_linking_match(fields, question, min_window_size, max_window_size)
logger.debug(f'field_linked_pairs: {field_linked_pairs}')
auto_schema_linking_cot, auto_schema_linkings = construct_schema_linking_cot(question, field_linked_pairs, literals)
logger.debug(f'auto_schema_linking_cot: {auto_schema_linking_cot}')
logger.debug(f'auto_schema_linkings: {auto_schema_linkings}')
return auto_schema_linking_cot, auto_schema_linkings
if __name__ == '__main__':
question = "没有获得过奖项的高校有哪几所?"
sql = "select 名称 from 高校 where 词条id not in ( select 高校id from 奖项 )"
min_window_size = 6
max_window_size = 10
generated_schema_linking_cot, generated_schema_linkings = auto_cot_run(question, sql, min_window_size, max_window_size)

View File

@@ -0,0 +1,83 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Union, Mapping
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from auto_cot import auto_cot_run
def transform_sql_example(question:str, current_date:str, table_name:str, field_list: Union[str, List[str]], prior_linkings: Union[str, Mapping[str,str]], prior_exts:str, sql:str=None):
db_schema = f"Table: {table_name}, Columns = {field_list}\nForeign_keys: []"
prior_linkings_pairs = []
if isinstance(prior_linkings, str):
prior_linkings = prior_linkings.strip('[]')
if prior_linkings.strip() == '':
prior_linkings = []
else:
prior_linkings = prior_linkings.split(',')
logger.debug(f'prior_linkings: {prior_linkings}')
for prior_linking in prior_linkings:
logger.debug(f'prior_linking: {prior_linking}')
entity_value, entity_type = prior_linking.split('->')
entity_linking = """{}‘是一个’{}""".format(entity_value, entity_type)
prior_linkings_pairs.append(entity_linking)
elif isinstance(prior_linkings, Mapping):
for entity_value, entity_type in prior_linkings.items():
entity_linking = """{}‘是一个’{}""".format(entity_value, entity_type)
prior_linkings_pairs.append(entity_linking)
prior_linkings_str = ''.join(prior_linkings_pairs)
current_data_str = """当前的日期是{}""".format(current_date)
question_augmented = """{question} (补充信息:{prior_linking}{current_date}) (备注: {prior_exts})""".format(question=question, prior_linking=prior_linkings_str, prior_exts=prior_exts, current_date=current_data_str)
return question_augmented, db_schema, sql
def transform_sql_example_autoCoT_run(examplar_list, min_window_size, max_window_size):
transformed_sql_examplar_list = []
for examplar in examplar_list:
question = examplar['question']
current_date = examplar['currentDate']
table_name = examplar['tableName']
field_list = examplar['fieldsList']
prior_linkings = examplar['priorSchemaLinks']
sql = examplar['sql']
if 'priorExts' not in examplar:
prior_exts = ''
else:
prior_exts = examplar['priorExts']
question_augmented, db_schema, sql = transform_sql_example(question=question, current_date=current_date, table_name=table_name, field_list=field_list, prior_linkings=prior_linkings, prior_exts=prior_exts, sql=sql)
logger.debug(f'question_augmented: {question_augmented}')
logger.debug(f'db_schema: {db_schema}')
logger.debug(f'sql: {sql}')
generated_schema_linking_cot, generated_schema_linkings = auto_cot_run(question_augmented, sql, min_window_size, max_window_size)
transformed_sql_examplar = dict()
transformed_sql_examplar['question'] = question
transformed_sql_examplar['questionAugmented'] = question_augmented
transformed_sql_examplar['modelName'] = table_name
transformed_sql_examplar['dbSchema'] = db_schema
transformed_sql_examplar['sql'] = sql
transformed_sql_examplar['generatedSchemaLinkingCoT'] = generated_schema_linking_cot
transformed_sql_examplar['generatedSchemaLinkings'] = generated_schema_linkings
logger.debug(f'transformed_sql_examplar: {transformed_sql_examplar}')
transformed_sql_examplar_list.append(transformed_sql_examplar)
return transformed_sql_examplar_list

View File

@@ -0,0 +1,79 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import List, Mapping
from chromadb.api import Collection
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from services.query_retrieval.retriever import ChromaCollectionRetriever
class FewShotPromptTemplate2(object):
def __init__(self, collection:Collection, retrieval_key:str, few_shot_seperator:str = "\n\n") -> None:
self.collection = collection
self.few_shot_retriever = ChromaCollectionRetriever(self.collection)
self.retrieval_key = retrieval_key
self.few_shot_seperator = few_shot_seperator
def add_few_shot_example(self, example_ids: List[str] , example_units: List[Mapping[str, str]])-> None:
query_text_list = []
for idx, example_unit in enumerate(example_units):
query_text_list.append(example_unit[self.retrieval_key])
self.few_shot_retriever.add_queries(query_text_list=query_text_list, query_id_list=example_ids, metadatas=example_units)
def update_few_shot_example(self, example_ids: List[str] , example_units: List[Mapping[str, str]])-> None:
query_text_list = []
for idx, example_unit in enumerate(example_units):
query_text_list.append(example_unit[self.retrieval_key])
self.few_shot_retriever.update_queries(query_text_list=query_text_list, query_id_list=example_ids, metadatas=example_units)
def delete_few_shot_example(self, example_ids: List[str])-> None:
self.few_shot_retriever.delete_queries_by_ids(query_ids=example_ids)
def get_few_shot_example(self, example_ids: List[str]):
return self.few_shot_retriever.get_query_by_ids(query_ids=example_ids)
def count_few_shot_example(self)-> int:
return self.few_shot_retriever.get_query_size()
def reload_few_shot_example(self, example_ids: List[str] , example_units: List[Mapping[str, str]])-> None:
logger.info(f"original {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")
self.few_shot_retriever.empty_query_collection()
logger.info(f"emptied {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")
self.add_few_shot_example(example_ids=example_ids, example_units=example_units)
logger.info(f"reloaded {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")
def _sub_dict(self, d:Mapping[str, str], keys:List[str])-> Mapping[str, str]:
return {k:d[k] for k in keys if k in d}
def retrieve_few_shot_example(self, query_text: str, retrieval_num: int, filter_condition: Mapping[str,str] =None)-> List[Mapping[str, str]]:
query_text_list = [query_text]
retrieval_res_list = self.few_shot_retriever.retrieval_query_run(query_texts_list=query_text_list,
filter_condition=filter_condition, n_results=retrieval_num)
retrieval_res_unit_list = retrieval_res_list[0]['retrieval']
return retrieval_res_unit_list
def make_few_shot_example_prompt(self, few_shot_template: str, example_keys: List[str],
few_shot_example_meta_list: List[Mapping[str, str]])-> str:
few_shot_example_str_unit_list = []
retrieval_metas_list = [self._sub_dict(few_shot_example_meta['metadata'], example_keys) for few_shot_example_meta in few_shot_example_meta_list]
for meta in retrieval_metas_list:
few_shot_example_str_unit_list.append(few_shot_template.format(**meta))
few_shot_example_str = self.few_shot_seperator.join(few_shot_example_str_unit_list)
return few_shot_example_str

View File

@@ -0,0 +1,40 @@
# -*- coding:utf-8 -*-
import json
import os
import sys
from typing import List, Mapping
import requests
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from config.config_parse import (
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
LLMPARSER_HOST, LLMPARSER_PORT,)
from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars
def text2dsl_agent_wrapper_setting_update(llm_host:str, llm_port:str,
sql_examplars:List[Mapping[str, str]],
example_nums:int, fewshot_nums:int, self_consistency_nums:int):
sql_ids = [str(i) for i in range(0, len(sql_examplars))]
url = f"http://{llm_host}:{llm_port}/query2sql_setting_update"
payload = {
"sqlExamplars":sql_examplars, "sqlIds": sql_ids,
"exampleNums":example_nums, "fewshotNums":fewshot_nums, "selfConsistencyNums":self_consistency_nums
}
headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
logger.info(response.text)
if __name__ == "__main__":
text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST,LLMPARSER_PORT,
sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)

View File

@@ -0,0 +1,59 @@
# -*- coding:utf-8 -*-
import re
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
def schema_link_parse(schema_link_output: str):
try:
schema_link_output = schema_link_output.strip()
pattern = r'Schema_links:(.*)'
schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
except Exception as e:
logger.exception(e)
schema_link_output = None
return schema_link_output
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'Schema_links:(\[.*?\])|Schema_links: (\[.*?\])'
schema_links_match = re.search(pattern, schema_linking_sql_combo_output)
if schema_links_match.group(1):
schema_links = schema_links_match.group(1)
elif schema_links_match.group(2):
schema_links = schema_links_match.group(2)
else:
schema_links = None
except Exception as e:
logger.exception(e)
schema_links = None
return schema_links
def combo_sql_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'SQL:(.*)'
sql_match = re.search(pattern, schema_linking_sql_combo_output)
if sql_match:
sql = sql_match.group(1)
else:
sql = None
except Exception as e:
logger.exception(e)
sql = None
return sql

View File

@@ -0,0 +1,63 @@
# -*- coding:utf-8 -*-
import asyncio
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
import json
from s2sql.constructor import FewShotPromptTemplate2
from s2sql.sql_agent import Text2DSLAgent, Text2DSLAgentAutoCoT, Text2DSLAgentWrapper
from instances.llm_instance import llm
from instances.chromadb_instance import client as chromadb_client
from instances.logging_instance import logger
from instances.text2vec_instance import emb_func
from few_shot_example.s2sql_exemplar import exemplars as sql_exemplars
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTACT_COLLECTION_NAME,
TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
ACT_MIN_WINDOWN_SIZE, ACT_MAX_WINDOWN_SIZE)
text2dsl_agent_collection = chromadb_client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"})
text2dsl_agent_act_collection = chromadb_client.get_or_create_collection(name=TEXT2DSLAGENTACT_COLLECTION_NAME,
embedding_function=emb_func,
metadata={"hnsw:space": "cosine"})
text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection,
retrieval_key="question",
few_shot_seperator='\n\n')
text2dsl_agent_act_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_act_collection,
retrieval_key="question",
few_shot_seperator='\n\n')
text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)
text2sql_agent_autoCoT = Text2DSLAgentAutoCoT(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
sql_example_prompter=text2dsl_agent_act_example_prompter, llm=llm,
auto_cot_min_window_size=ACT_MIN_WINDOWN_SIZE, auto_cot_max_window_size=ACT_MAX_WINDOWN_SIZE)
sql_ids = [str(i) for i in range(0, len(sql_exemplars))]
text2sql_agent.reload_setting(sql_ids, sql_exemplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
if text2sql_agent_autoCoT.count_examples()==0:
source_dir_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
example_dir_path = os.path.join(source_dir_path, 'few_shot_example')
example_json_file = os.path.join(example_dir_path, 's2sql_exemplar3_transformed.json.json')
with open(example_json_file, 'r', encoding='utf-8') as f:
transformed_sql_examplar_list = json.load(f)
transformed_sql_examplar_ids = [str(i) for i in range(0, len(transformed_sql_examplar_list))]
text2sql_agent_autoCoT.reload_setting_autoCoT(transformed_sql_examplar_ids, transformed_sql_examplar_list, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
text2sql_agent_router = Text2DSLAgentWrapper(sql_agent_act=text2sql_agent_autoCoT)

View File

@@ -0,0 +1,812 @@
import os
import sys
from typing import List, Union, Mapping, Any
from collections import Counter
import random
import asyncio
from enum import Enum
from langchain.llms.base import BaseLLM
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from s2sql.constructor import FewShotPromptTemplate2
from s2sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
from s2sql.auto_cot_run import transform_sql_example, transform_sql_example_autoCoT_run
class Text2DSLAgentBase(object):
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM) -> None:
self.num_fewshots = num_fewshots
self.num_examples = num_examples
assert self.num_fewshots <= self.num_examples
self.num_self_consistency = num_self_consistency
self.llm = llm
self.sql_example_prompter = sql_example_prompter
def get_examples_candidates(self, question: str, filter_condition: Mapping[str, str], num_examples: int)->List[Mapping[str, str]]:
few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(question, num_examples, filter_condition)
if len(few_shot_example_meta_list) == num_examples:
return few_shot_example_meta_list
elif len(few_shot_example_meta_list) < num_examples:
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} < num_examples: {num_examples}")
existed_id_set = set([item['id'] for item in few_shot_example_meta_list])
extra_few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text=question, retrieval_num=num_examples, filter_condition=None)
for item in extra_few_shot_example_meta_list:
if item['id'] not in existed_id_set:
few_shot_example_meta_list.append(item)
existed_id_set.add(item['id'])
if len(few_shot_example_meta_list) == num_examples:
break
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} = num_examples: {num_examples}")
return few_shot_example_meta_list
else:
logger.info(f"few_shot_example_meta_list size: {len(few_shot_example_meta_list)} > num_examples: {num_examples}")
few_shot_example_meta_list = few_shot_example_meta_list[:num_examples]
return few_shot_example_meta_list
def get_fewshot_example_combos(self, example_meta_list:List[Mapping[str, str]], num_fewshots:int)-> List[List[Mapping[str, str]]]:
fewshot_example_list = []
for i in range(0, self.num_self_consistency):
random.shuffle(example_meta_list)
fewshot_example_list.append(example_meta_list[:num_fewshots])
return fewshot_example_list
def self_consistency_vote(self, output_res_pool:List[str]):
output_res_counts = Counter(output_res_pool)
output_res_max = output_res_counts.most_common(1)[0][0]
total_output_num = len(output_res_pool)
vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()}
return output_res_max, vote_percentage
def schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]:
schema_linking_list_unify = []
for schema_linking_str in schema_linking_list:
schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')]))
schema_linking_str_unify = f'[{schema_linking_str_unify}]'
schema_linking_list_unify.append(schema_linking_str_unify)
return schema_linking_list_unify
class Text2DSLAgentAutoCoT(Text2DSLAgentBase):
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM,
auto_cot_min_window_size: int, auto_cot_max_window_size: int):
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter, llm)
assert auto_cot_min_window_size <= auto_cot_max_window_size
self.auto_cot_min_window_size = auto_cot_min_window_size
self.auto_cot_max_window_size = auto_cot_max_window_size
def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
self.num_fewshots = num_fewshots
self.num_examples = num_examples
assert self.num_fewshots <= self.num_examples
self.num_self_consistency = num_self_consistency
assert self.num_self_consistency >= 1
new_sql_example_unit_list = transform_sql_example_autoCoT_run(sql_example_units, self.auto_cot_min_window_size, self.auto_cot_max_window_size)
self.sql_example_prompter.reload_few_shot_example(sql_example_ids, new_sql_example_unit_list)
def reload_setting_autoCoT(self, sql_example_ids: List[str], auto_cot_sql_example_units: List[Mapping[str,str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
self.num_fewshots = num_fewshots
self.num_examples = num_examples
assert self.num_fewshots <= self.num_examples
self.num_self_consistency = num_self_consistency
assert self.num_self_consistency >= 1
self.sql_example_prompter.reload_few_shot_example(sql_example_ids, auto_cot_sql_example_units)
def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]]):
new_sql_example_unit_list = transform_sql_example_autoCoT_run(sql_example_units, self.auto_cot_min_window_size, self.auto_cot_max_window_size)
self.sql_example_prompter.add_few_shot_example(sql_example_ids, new_sql_example_unit_list)
def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str,str]]):
new_sql_example_unit_list = transform_sql_example_autoCoT_run(sql_example_units, self.auto_cot_min_window_size, self.auto_cot_max_window_size)
self.sql_example_prompter.update_few_shot_example(sql_example_ids, new_sql_example_unit_list)
def delete_examples(self, sql_example_ids: List[str]):
self.sql_example_prompter.delete_few_shot_example(sql_example_ids)
def count_examples(self):
return self.sql_example_prompter.count_few_shot_example()
def get_examples(self, sql_example_ids: List[str]):
return self.sql_example_prompter.get_few_shot_example(sql_example_ids)
def generate_schema_linking_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]])-> str:
instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys."
schema_linking_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT"]
schema_linking_example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}"
schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
example_keys=schema_linking_example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Lets think step by step. In the question "{questionAugmented2}", we are asked:"""
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented)
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
logger.info(f'schema_linking_prompt: {schema_linking_prompt}')
return schema_linking_prompt
def generate_schema_linking_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
schema_linking_prompt_pool = []
for fewshot_example_list in fewshot_example_list_pool:
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list)
schema_linking_prompt_pool.append(schema_linking_prompt)
return schema_linking_prompt_pool
def generate_sql_prompt(self, question: str, domain_name: str,fields_list: List[str],
schema_link_str: str, current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
fewshot_example_list:List[Mapping[str, str]])-> str:
instruction = "# Use the the schema links to generate the SQL queries for each of the questions."
sql_example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"]
sql_example_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {generatedSchemaLinkings}\nSQL: {sql}"
sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
new_case_template = "{dbSchema}\nQ: {questionAugmented}\nSchema_links: {schemaLinkings}\nSQL: "
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented=question_augmented, schemaLinkings=schema_link_str)
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
logger.info(f'sql_example_prompt: {sql_example_prompt}')
return sql_example_prompt
def generate_sql_prompt_pool(self, question: str, domain_name: str,fields_list: List[str],
schema_link_str_pool: List[str], current_date: str, prior_schema_links: Mapping[str,str], prior_exts:str,
fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
sql_prompt_pool = []
for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
sql_prompt = self.generate_sql_prompt(question, domain_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_list)
sql_prompt_pool.append(sql_prompt)
return sql_prompt_pool
def generate_schema_linking_sql_prompt(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list:List[Mapping[str, str]]):
instruction = "# Find the schema_links for generating SQL queries for each question based on the database schema and Foreign keys. Then use the the schema links to generate the SQL queries for each of the questions."
example_keys = ["questionAugmented", "dbSchema", "generatedSchemaLinkingCoT","sql"]
example_template = "{dbSchema}\nQ: {questionAugmented}\nA: {generatedSchemaLinkingCoT}\nSQL: {sql}"
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
example_keys=example_keys,
few_shot_example_meta_list=fewshot_example_list)
question_augmented, db_schema, _ = transform_sql_example(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts)
new_case_template = """{dbSchema}\nQ: {questionAugmented1}\nA: Lets think step by step. In the question "{questionAugmented2}", we are asked:"""
new_case_prompt = new_case_template.format(dbSchema=db_schema, questionAugmented1=question_augmented, questionAugmented2=question_augmented)
prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt
logger.info(f'schema_linking_sql_prompt: {prompt}')
return prompt
def generate_schema_linking_sql_prompt_pool(self, question: str, current_date:str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], prior_exts:str, fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
schema_linking_sql_prompt_pool = []
for fewshot_example_list in fewshot_example_list_pool:
schema_linking_sql_prompt = self.generate_schema_linking_sql_prompt(question, current_date, domain_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list)
schema_linking_sql_prompt_pool.append(schema_linking_sql_prompt)
return schema_linking_sql_prompt_pool
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_prompt = self.generate_schema_linking_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
schema_link_output = await self.llm._call_async(schema_linking_prompt)
logger.debug("schema_link_output->{}".format(schema_link_output))
schema_link_str = schema_link_parse(schema_link_output)
logger.debug("schema_link_str->{}".format(schema_link_str))
sql_prompt = self.generate_sql_prompt(question, model_name, fields_list, schema_link_str, current_date, prior_schema_links, prior_exts, fewshot_example_meta_list)
logger.debug("sql_prompt->{}".format(sql_prompt))
sql_output = await self.llm._call_async(sql_prompt)
resp = dict()
resp['question'] = question
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt]
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
logger.info("resp: {}".format(resp))
return resp
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_meta_list)
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
logger.debug("schema_linking_sql_shortcut_output->{}".format(schema_linking_sql_shortcut_output))
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)
resp = dict()
resp['question'] = question
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = [schema_linking_sql_shortcut_prompt]
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
logger.info("resp: {}".format(resp))
return resp
async def generate_schema_linking_tasks(self, question: str, model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug("schema_linking_prompt_pool->{}".format(schema_linking_prompt_pool))
schema_linking_output_pool = await asyncio.gather(*[self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool])
logger.debug("schema_linking_output_pool->{}".format(schema_linking_output_pool))
schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool]
return schema_linking_str_pool, schema_linking_output_pool, schema_linking_prompt_pool
async def generate_sql_tasks(self, question: str, model_name: str, fields_list: List[str], schema_link_str_pool: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
sql_prompt_pool = self.generate_sql_prompt_pool(question, model_name, fields_list, schema_link_str_pool, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug("sql_prompt_pool->{}".format(sql_prompt_pool))
sql_output_pool = await asyncio.gather(*[self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool])
logger.debug("sql_output_pool->{}".format(sql_output_pool))
return sql_output_pool, sql_prompt_pool
async def generate_schema_linking_sql_tasks(self, question: str, model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, fewshot_example_list_combo:List[List[Mapping[str, str]]]):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, current_date, model_name, fields_list, prior_schema_links, prior_exts, fewshot_example_list_combo)
schema_linking_sql_output_task_pool = [self.llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
logger.debug("schema_linking_sql_output_res_pool->{}".format(schema_linking_sql_output_res_pool))
return schema_linking_sql_output_res_pool, schema_linking_sql_prompt_pool, schema_linking_sql_output_task_pool
async def tasks_run(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_candidate_list, _, schema_linking_prompt_list = await self.generate_schema_linking_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates, sql_output_prompt_list = await self.generate_sql_tasks(question, model_name, fields_list, schema_linking_candidate_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
resp = dict()
resp['question'] = question
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = [schema_linking_prompt+'\n\n'+sql_prompt for schema_linking_prompt, sql_prompt in zip(schema_linking_prompt_list, sql_output_prompt_list)]
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
resp['sqlOutput'] = sql_output_max
resp['sqlWeight'] = sql_output_vote_percentage
logger.info("resp: {}".format(resp))
return resp
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str,str], model_name: str, fields_list: List[str],
current_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("question: {}".format(question))
logger.info("filter_condition: {}".format(filter_condition))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(current_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_sql_output_candidates, schema_linking_sql_prompt_list, _ = await self.generate_schema_linking_sql_tasks(question, model_name, fields_list, current_date, prior_schema_links, prior_exts, fewshot_example_list_combo)
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
schema_linking_output_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_output_candidate_list)
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_output_candidate_sorted_list)
sql_output_candidate_list = [combo_sql_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'sql_output_candidate_list:{sql_output_candidate_list}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candidate_list)
resp = dict()
resp['question'] = question
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['priorExts'] = prior_exts
resp['currentDate'] = current_date
resp['prompt'] = schema_linking_sql_prompt_list
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
resp['sqlOutput'] = sql_output_max
resp['sqlWeight'] = sql_output_vote_percentage
logger.info("resp: {}".format(resp))
return resp
class Text2DSLAgent(Text2DSLAgentBase):
def __init__(self, num_fewshots:int, num_examples:int, num_self_consistency:int,
sql_example_prompter:FewShotPromptTemplate2, llm: BaseLLM,) -> None:
super().__init__(num_fewshots, num_examples, num_self_consistency, sql_example_prompter, llm)
def reload_setting(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]], num_examples:int, num_fewshots:int, num_self_consistency:int):
self.num_fewshots = num_fewshots
self.num_examples = num_examples
assert self.num_fewshots <= self.num_examples
self.num_self_consistency = num_self_consistency
assert self.num_self_consistency >= 1
self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)
def add_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)
def update_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)
def delete_examples(self, sql_example_ids:List[str]):
self.sql_example_prompter.delete_few_shot_example(sql_example_ids)
def get_examples(self, sql_example_ids: List[str]):
return self.sql_example_prompter.get_few_shot_example(sql_example_ids)
def count_examples(self):
return self.sql_example_prompter.count_few_shot_example()
def generate_schema_linking_prompt(self, question: str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], fewshot_example_list:List[Mapping[str, str]])-> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"
schema_linking_example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"]
schema_linking_example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"
schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
example_keys=schema_linking_example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_schema_links_str, question=question)
schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
return schema_linking_prompt
def generate_schema_linking_prompt_pool(self, question: str, domain_name: str, fields_list: List[str],
prior_schema_links: Mapping[str,str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
schema_linking_prompt_pool = []
for fewshot_example_list in fewshot_example_list_pool:
schema_linking_prompt = self.generate_schema_linking_prompt(question, domain_name, fields_list, prior_schema_links, fewshot_example_list)
schema_linking_prompt_pool.append(schema_linking_prompt)
return schema_linking_prompt_pool
def generate_sql_prompt(self, question: str, domain_name: str,
schema_link_str: str, data_date: str,
fewshot_example_list:List[Mapping[str, str]])-> str:
instruction = "# 根据schema_links为每个问题生成SQL查询语句"
sql_example_keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"
sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
example_keys=sql_example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
new_case_prompt = new_case_template.format(question=question, currentDate=data_date, tableName=domain_name, schemaLinks=schema_link_str)
sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt
return sql_example_prompt
def generate_sql_prompt_pool(self, question: str, domain_name: str, data_date: str,
schema_link_str_pool: List[str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
sql_prompt_pool = []
for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
sql_prompt = self.generate_sql_prompt(question, domain_name, schema_link_str, data_date, fewshot_example_list)
sql_prompt_pool.append(sql_prompt)
return sql_prompt_pool
def generate_schema_linking_sql_prompt(self, question: str,
domain_name: str,
data_date : str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
fewshot_example_list:List[Mapping[str, str]]):
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "currentDate", "question", "analysis", "schemaLinks", "sql"]
example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}\nSQL:{sql}"
fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
example_keys=example_keys,
few_shot_example_meta_list=fewshot_example_list)
new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析: 让我们一步一步地思考。"
new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_schema_links_str, currentDate=data_date, question=question)
prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt
return prompt
def generate_schema_linking_sql_prompt_pool(self, question: str, domain_name: str, fields_list: List[str], data_date: str,
prior_schema_links: Mapping[str,str], fewshot_example_list_pool:List[List[Mapping[str, str]]])-> List[str]:
schema_linking_sql_prompt_pool = []
for fewshot_example_list in fewshot_example_list_pool:
schema_linking_sql_prompt = self.generate_schema_linking_sql_prompt(question, domain_name, data_date, fields_list, prior_schema_links, fewshot_example_list)
schema_linking_sql_prompt_pool.append(schema_linking_sql_prompt)
return schema_linking_sql_prompt_pool
def self_consistency_vote(self, output_res_pool:List[str]):
output_res_counts = Counter(output_res_pool)
output_res_max = output_res_counts.most_common(1)[0][0]
total_output_num = len(output_res_pool)
vote_percentage = {k: (v/total_output_num) for k,v in output_res_counts.items()}
return output_res_max, vote_percentage
def schema_linking_list_str_unify(self, schema_linking_list: List[str])-> List[str]:
schema_linking_list_unify = []
for schema_linking_str in schema_linking_list:
schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')]))
schema_linking_str_unify = f'[{schema_linking_str_unify}]'
schema_linking_list_unify.append(schema_linking_str_unify)
return schema_linking_list_unify
async def generate_schema_linking_tasks(self, question: str, domain_name: str,
fields_list: List[str], prior_schema_links: Mapping[str,str],
fewshot_example_list_combo:List[List[Mapping[str, str]]]):
schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(question, domain_name,
fields_list, prior_schema_links,
fewshot_example_list_combo)
schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
schema_linking_output_pool = await asyncio.gather(*schema_linking_output_task_pool)
logger.debug(f'schema_linking_output_pool:{schema_linking_output_pool}')
schema_linking_str_pool = [schema_link_parse(schema_linking_output) for schema_linking_output in schema_linking_output_pool]
return schema_linking_str_pool
async def generate_sql_tasks(self, question: str, domain_name: str, data_date: str,
schema_link_str_pool: List[str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
sql_prompt_pool = self.generate_sql_prompt_pool(question, domain_name, schema_link_str_pool, data_date, fewshot_example_list_combo)
sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')
return sql_output_res_pool
async def generate_schema_linking_sql_tasks(self, question: str, domain_name: str, fields_list: List[str], data_date: str,
prior_schema_links: Mapping[str,str], fewshot_example_list_combo:List[List[Mapping[str, str]]]):
schema_linking_sql_prompt_pool = self.generate_schema_linking_sql_prompt_pool(question, domain_name, fields_list, data_date, prior_schema_links, fewshot_example_list_combo)
schema_linking_sql_output_task_pool = [self.llm._call_async(schema_linking_sql_prompt) for schema_linking_sql_prompt in schema_linking_sql_prompt_pool]
schema_linking_sql_output_res_pool = await asyncio.gather(*schema_linking_sql_output_task_pool)
logger.debug(f'schema_linking_sql_output_res_pool:{schema_linking_sql_output_res_pool}')
return schema_linking_sql_output_res_pool
async def tasks_run(self, question: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
logger.info("question: {}".format(question))
logger.info("domain_name: {}".format(domain_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
if prior_exts != '':
question = question + ' 备注:'+prior_exts
logger.info("question_prior_exts: {}".format(question))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_candidate_list = await self.generate_schema_linking_tasks(question, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)
sql_output_candicates = await self.generate_sql_tasks(question, domain_name, data_date, schema_linking_candidate_list,fewshot_example_list_combo)
logger.debug(f'sql_output_candicates:{sql_output_candicates}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candicates)
resp = dict()
resp['question'] = question
resp['model'] = domain_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
resp['sqlOutput'] = sql_output_max
resp['sqlWeight'] = sql_output_vote_percentage
logger.info("resp: {}".format(resp))
return resp
async def tasks_run_shortcut(self, question: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str,str], data_date: str, prior_exts: str):
logger.info("question: {}".format(question))
logger.info("domain_name: {}".format(domain_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("current_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
if prior_exts != '':
question = question + ' 备注:'+prior_exts
logger.info("question_prior_exts: {}".format(question))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list, self.num_fewshots)
schema_linking_sql_output_candidates = await self.generate_schema_linking_sql_tasks(question, domain_name, fields_list, data_date, prior_schema_links, fewshot_example_list_combo)
logger.debug(f'schema_linking_sql_output_candidates:{schema_linking_sql_output_candidates}')
schema_linking_output_candidate_list = [combo_schema_link_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'schema_linking_sql_output_candidate_list:{schema_linking_output_candidate_list}')
schema_linking_output_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_output_candidate_list)
schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_output_candidate_sorted_list)
sql_output_candidate_list = [combo_sql_parse(schema_linking_sql_output_candidate) for schema_linking_sql_output_candidate in schema_linking_sql_output_candidates]
logger.debug(f'sql_output_candidate_list:{sql_output_candidate_list}')
sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candidate_list)
resp = dict()
resp['question'] = question
resp['model'] = domain_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkStr'] = schema_linking_output_max
resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage
resp['sqlOutput'] = sql_output_max
resp['sqlWeight'] = sql_output_vote_percentage
logger.info("resp: {}".format(resp))
return resp
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("question: {}".format(question))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("data_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
if prior_exts != '':
question = question + ' 备注:'+prior_exts
logger.info("question_prior_exts: {}".format(question))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_prompt = self.generate_schema_linking_prompt(question, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
schema_link_output = await self.llm._call_async(schema_linking_prompt)
schema_link_str = schema_link_parse(schema_link_output)
sql_prompt = self.generate_sql_prompt(question, model_name, schema_link_str, data_date, fewshot_example_meta_list)
logger.debug("sql_prompt->{}".format(sql_prompt))
sql_output = await self.llm._call_async(sql_prompt)
resp = dict()
resp['question'] = question
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
logger.info("resp: {}".format(resp))
return resp
async def async_query2sql_shortcut(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str):
logger.info("question: {}".format(question))
logger.info("model_name: {}".format(model_name))
logger.info("fields_list: {}".format(fields_list))
logger.info("data_date: {}".format(data_date))
logger.info("prior_schema_links: {}".format(prior_schema_links))
logger.info("prior_exts: {}".format(prior_exts))
if prior_exts != '':
question = question + ' 备注:'+prior_exts
logger.info("question_prior_exts: {}".format(question))
fewshot_example_meta_list = self.get_examples_candidates(question, filter_condition, self.num_examples)
schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(question, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)
schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)
resp = dict()
resp['question'] = question
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = data_date
resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
logger.info("resp: {}".format(resp))
return resp
class SqlModeEnum(Enum):
VALUE5 = '1_pass_auto_cot'
VALUE6 = '1_pass_auto_cot_self_consistency'
VALUE7 = '2_pass_auto_cot'
VALUE8 = '2_pass_auto_cot_self_consistency'
class Text2DSLAgentWrapper(object):
def __init__(self, sql_agent_act:Text2DSLAgentAutoCoT):
self.sql_agent_act = sql_agent_act
async def async_query2sql(self, question: str, filter_condition: Mapping[str,str],
model_name: str, fields_list: List[str],
data_date: str, prior_schema_links: Mapping[str,str], prior_exts: str, sql_generation_mode: str):
if sql_generation_mode not in (sql_mode.value for sql_mode in SqlModeEnum):
raise ValueError(f"sql_generation_mode: {sql_generation_mode} is not in SqlModeEnum")
if sql_generation_mode == '1_pass_auto_cot':
logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.async_query2sql_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
return resp
elif sql_generation_mode == '1_pass_auto_cot_self_consistency':
logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.tasks_run_shortcut(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
return resp
elif sql_generation_mode == '2_pass_auto_cot':
logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.async_query2sql(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
return resp
elif sql_generation_mode == '2_pass_auto_cot_self_consistency':
logger.info(f"sql wrapper: {sql_generation_mode}")
resp = await self.sql_agent_act.tasks_run(question=question, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, current_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
return resp
else:
raise ValueError(f'sql_generation_mode:{sql_generation_mode} is not in SqlModeEnum')
def update_configs(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]],
num_examples: int, num_fewshots: int, num_self_consistency: int):
self.sql_agent_act.reload_setting(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units, num_examples=num_examples, num_fewshots=num_fewshots, num_self_consistency=num_self_consistency)
def add_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
self.sql_agent_act.add_examples(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units)
def update_examples(self, sql_example_ids:List[str], sql_example_units: List[Mapping[str, str]]):
self.sql_agent_act.update_examples(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units)
def delete_examples(self, sql_example_ids:List[str]):
self.sql_agent_act.delete_examples(sql_example_ids=sql_example_ids)
def get_examples(self, sql_example_ids: List[str]):
sql_agent_act_examples = self.sql_agent_act.get_examples(sql_example_ids=sql_example_ids)
return sql_agent_act_examples
def count_examples(self):
sql_agent_examples_act_cnt = self.sql_agent_act.count_examples()
return sql_agent_examples_act_cnt