(fix)(chat):fix typo in s2sql and add prompt to output. (#581)

* 1.refactor the retrieval module. 2.refactor the http service module. 3.upgrade text2sql output format the parse for absolute time related expression in query.

* fix bug.

* upgrade the config module, now support config llm suppoted by langchain.

* fix conflicts.

* update text2sql config reload to be compitable with new config format.

* modify default config.

* 1.add self-consistency feature for text2sql. 2.upgrade llm api call from sync to async. 3.refactor text2sql module. 4. refactor semantical retriever modules.

* merege with upstream master

* add general retrieve service.

* add api service for sql_agent for crud opereations of few-shots examples.

* modify requirements

* add auto-cot feature

* 1. output log to a fixed log file.  2.allow few-shots examples tied to data model, and add strategy that extend examples when retrieved examples tied to a data model is not enough. 3. fix misformat in s2ql args.

* add prior_ext to output.

* fix type in in s2sql

* add prompt to output.

---------

Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
codescracker
2024-01-02 16:35:33 +08:00
committed by GitHub
parent 49f0a4dc1d
commit af1c560cc4
11 changed files with 27 additions and 19 deletions

View File

@@ -0,0 +1,167 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union, Tuple
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from instances.logging_instance import logger
from instances.text2vec_instance import emb_func
from sqlglot import parse_one, exp
import numpy as np
def sql2schema_linking(sql: str):
sql_ast = parse_one(sql)
fields_raw = []
table_alias_map = dict()
literals = []
fields = []
for literal in sql_ast.find_all(exp.Literal):
literals.append(literal.output_name)
for column in sql_ast.find_all(exp.Column):
fields_raw.append({
'column_table_alias': column.table,
'column_name': column.name,
})
for table in sql_ast.find_all(exp.Table):
if table.alias not in table_alias_map:
table_alias_map[table.alias] = table.name
logger.debug(f'literals: {literals}')
logger.debug(f'fields_raw: {fields_raw}')
logger.debug(f'table_alias_map: {table_alias_map}')
for field in fields_raw:
column_table_alias = field['column_table_alias']
column_name = field['column_name']
if column_table_alias.strip() == '':
column_table = ''
fields.append((column_table, column_name))
elif column_table_alias in table_alias_map:
column_table = table_alias_map[column_table_alias]
fields.append((column_table, column_name))
elif column_table_alias in table_alias_map.values():
column_table = column_table_alias
fields.append((column_table, column_name))
else:
logger.error(f'column_table_alias: {column_table_alias} not in table_alias_map: {table_alias_map}')
raise Exception(f'column_table_alias: {column_table_alias} not in table_alias_map: {table_alias_map}')
return {
'fields': list(set(fields)),
'literals': literals
}
def get_question_slices(question: str, min_window_size: int, max_window_size: int):
assert min_window_size <= max_window_size
assert min_window_size > 1
assert max_window_size < len(question)+1
question_slices = []
for i in range(len(question)):
for j in range(i+1, len(question)+1):
if j-i >= min_window_size and j-i <= max_window_size:
question_slices.append(question[i:j])
return question_slices
def schema_linking_match(fields: List[Tuple[str,str]], question: str, min_window_size: int, max_window_size: int):
question_slices = get_question_slices(question, min_window_size, max_window_size)
assert len(question_slices) > 0
logger.debug('question_slices_len:{}'.format(len(question_slices)))
logger.debug(f'question_slices: {question_slices}')
question_slices_embeddings = emb_func(question_slices)
fields_embeddings = emb_func([field[1] for field in fields])
fields_embeddings = np.array(fields_embeddings) # (n_fields, 768)
question_slices_embeddings = np.array(question_slices_embeddings) # (n_question_slices, 768)
question_slices_embeddings_norm = question_slices_embeddings / np.linalg.norm(question_slices_embeddings, axis=1, keepdims=True) # (n_question_slices, 768)
question_slices_embeddings_norm_transpose = question_slices_embeddings_norm.T # (768, n_question_slices)
if len(fields) > 0:
fields_embeddings_norm = fields_embeddings / np.linalg.norm(fields_embeddings, axis=1, keepdims=True) # (n_fields, 768)
fields_question_slices_similarity = np.matmul(fields_embeddings_norm, question_slices_embeddings_norm_transpose) # (n_fields, n_question_slices)
logger.debug('fields_question_slices_similarity_max:{}'.format(np.max(fields_question_slices_similarity, axis=1)))
fields_question_slices_argmax = np.argmax(fields_question_slices_similarity, axis=1) # (n_fields, )
logger.debug('fields_question_slices_argmax:{}'.format(fields_question_slices_argmax))
fields_question_slices_pair = []
for i in range(len(fields)):
if fields[i][0]!="":
fields_question_slices_pair.append((fields[i][0]+'.'+fields[i][1], question_slices[fields_question_slices_argmax[i]]))
else:
fields_question_slices_pair.append((fields[i][1], question_slices[fields_question_slices_argmax[i]]))
logger.debug(f'fields_question_slices_pair: {fields_question_slices_pair}')
else:
fields_question_slices_pair = []
return fields_question_slices_pair
def construct_schema_linking_cot(question:str, fields_question_slices_pair:List[Tuple[str,str]], literals_list:List[str]):
cot_intro= """Lets think step by step. In the question "{question}", we are asked:""".format(question=question)
schema_linkings_list = []
fields_cot_template = """"{question_slice}" so we need column = [{field}]"""
fields_cot_list = []
for field, question_slice in fields_question_slices_pair:
fields_cot_list.append(fields_cot_template.format(question_slice=question_slice, field=field))
schema_linkings_list.append(field)
fields_cot = '\n'.join(fields_cot_list)
literals_cot_template = """Based on the tables, columns, and Foreign_keys, The set of possible cell values are = [{literals}]. So the Schema_links are:"""
literals_cot = literals_cot_template.format(literals=",".join(literals_list))
schema_linkings_list += literals_list
schema_linking_str = '[' + ",".join(schema_linkings_list) + ']'
schema_linkings = 'Schema_links: '+ schema_linking_str
cot = """{cot_intro}""".format(cot_intro=cot_intro)
if len(fields_cot_list) > 0:
cot += '\n' + fields_cot
cot += '\n' + literals_cot
cot += '\n' + schema_linkings
return cot, schema_linking_str
def auto_cot_run(question, sql, min_window_size, max_window_size):
sql_entity = sql2schema_linking(sql)
logger.debug(f'sql_entity: {sql_entity}')
fields = sql_entity['fields']
literals = sql_entity['literals']
field_linked_pairs = schema_linking_match(fields, question, min_window_size, max_window_size)
logger.debug(f'field_linked_pairs: {field_linked_pairs}')
auto_schema_linking_cot, auto_schema_linkings = construct_schema_linking_cot(question, field_linked_pairs, literals)
logger.debug(f'auto_schema_linking_cot: {auto_schema_linking_cot}')
logger.debug(f'auto_schema_linkings: {auto_schema_linkings}')
return auto_schema_linking_cot, auto_schema_linkings
if __name__ == '__main__':
question = "没有获得过奖项的高校有哪几所?"
sql = "select 名称 from 高校 where 词条id not in ( select 高校id from 奖项 )"
min_window_size = 6
max_window_size = 10
generated_schema_linking_cot, generated_schema_linkings = auto_cot_run(question, sql, min_window_size, max_window_size)