(improvement)(text2sql) add text2sql feature that only call LLM once, and add correponding configs and docs. (#102)

Co-authored-by: shaoweigong <shaoweigong@tencent.com>
This commit is contained in:
codescracker
2023-09-20 10:22:20 +08:00
committed by GitHub
parent 6a5a95e543
commit c8ff37e304
9 changed files with 173 additions and 55 deletions

View File

@@ -15,6 +15,7 @@ CHROMA_DB_PERSIST_DIR = 'chm_db'
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection" PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection" TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
TEXT2DSL_IS_SHORTCUT = False
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR) CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

View File

@@ -22,10 +22,8 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
def reload_sql_example_collection(vectorstore:Chroma, def reload_sql_example_collection(vectorstore:Chroma,
sql_examplars:List[Mapping[str, str]], sql_examplars:List[Mapping[str, str]],
schema_linking_example_selector:SemanticSimilarityExampleSelector,
sql_example_selector:SemanticSimilarityExampleSelector, sql_example_selector:SemanticSimilarityExampleSelector,
example_nums:int example_nums:int
): ):
@@ -35,20 +33,16 @@ def reload_sql_example_collection(vectorstore:Chroma,
print("emptied sql_examples_collection size:", vectorstore._collection.count()) print("emptied sql_examples_collection size:", vectorstore._collection.count())
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"], input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"]) example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
for example in sql_examplars: for example in sql_examplars:
schema_linking_example_selector.add_example(example) sql_example_selector.add_example(example)
print("reloaded sql_examples_collection size:", vectorstore._collection.count()) print("reloaded sql_examples_collection size:", vectorstore._collection.count())
return vectorstore, schema_linking_example_selector, sql_example_selector return vectorstore, sql_example_selector
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME, sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
@@ -57,22 +51,14 @@ sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"], input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"]) example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"])
if sql_examples_vectorstore._collection.count() > 0: if sql_examples_vectorstore._collection.count() > 0:
print("examples already in sql_vectorstore") print("examples already in sql_vectorstore")
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count()) print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
if sql_examples_vectorstore._collection.count() < len(sql_examplars):
print("sql_examplars size:", len(sql_examplars))
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
else:
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
print("sql_examplars size:", len(sql_examplars))
sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())

View File

@@ -8,24 +8,22 @@ import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT
from few_shot_example.sql_exampler import examplars as sql_examplars from few_shot_example.sql_exampler import examplars as sql_examplars
from run_config import LLMPARSER_HOST from run_config import LLMPARSER_HOST, LLMPARSER_PORT
from run_config import LLMPARSER_PORT
def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str, def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str,
sql_examplars:List[Mapping[str, str]], example_nums:int): sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool):
url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/" url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/"
print("url: ", url) print("url: ", url)
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums} payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut}
headers = {'content-type': 'application/json'} headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers) response = requests.post(url, data=json.dumps(payload), headers=headers)
print(response.text) print(response.text)
if __name__ == "__main__": if __name__ == "__main__":
arguments = sys.argv
text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT, text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT,
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM) sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT)

View File

@@ -10,4 +10,36 @@ def schema_link_parse(schema_link_output):
print(e) print(e)
schema_link_output = None schema_link_output = None
return schema_link_output return schema_link_output
def combo_schema_link_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'Schema_links:(\[.*?\])'
schema_links_match = re.search(pattern, schema_linking_sql_combo_output)
if schema_links_match:
schema_links = schema_links_match.group(1)
else:
schema_links = None
except Exception as e:
print(e)
schema_links = None
return schema_links
def combo_sql_parse(schema_linking_sql_combo_output: str):
try:
schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
pattern = r'SQL:(.*)'
sql_match = re.search(pattern, schema_linking_sql_combo_output)
if sql_match:
sql = sql_match.group(1)
else:
sql = None
except Exception as e:
print(e)
sql = None
return sql

View File

@@ -73,3 +73,38 @@ def sql_exampler(user_query: str,
schema_links=schema_link_str) schema_links=schema_link_str)
return sql_example_prompt return sql_example_prompt
def schema_linking_sql_combo_examplar(user_query: str,
domain_name: str,
data_date : str,
fields_list: List[str],
prior_schema_links: Mapping[str,str],
example_selector: SemanticSimilarityExampleSelector) -> str:
prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']'
example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"],
template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}")
instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"
schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。"
schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt_template,
example_separator="\n\n",
prefix=instruction,
input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"],
suffix=schema_linking_sql_combo_prompt
)
schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name,
fields_list=fields_list,
prior_schema_links=prior_schema_links_str,
current_date=data_date,
question=user_query)
return schema_linking_sql_combo_example_prompt

View File

@@ -7,32 +7,37 @@ import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__))) sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sql.prompt_maker import schema_linking_exampler, sql_exampler from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar
from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection
from sql.output_parser import schema_link_parse from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse
from util.llm_instance import llm from util.llm_instance import llm
from run_config import TEXT2DSL_IS_SHORTCUT
class Text2DSLAgent(object): class Text2DSLAgent(object):
def __init__(self): def __init__(self):
self.schema_linking_exampler = schema_linking_exampler self.schema_linking_exampler = schema_linking_exampler
self.sql_exampler = sql_exampler self.sql_exampler = sql_exampler
self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar
self.sql_examples_vectorstore = sql_examples_vectorstore self.sql_examples_vectorstore = sql_examples_vectorstore
self.schema_linking_example_selector = schema_linking_example_selector
self.sql_example_selector = sql_example_selector self.sql_example_selector = sql_example_selector
self.schema_link_parse = schema_link_parse self.schema_link_parse = schema_link_parse
self.combo_schema_link_parse = combo_schema_link_parse
self.combo_sql_parse = combo_sql_parse
self.llm = llm self.llm = llm
def update_examples(self, sql_examplars, example_nums): self.is_shortcut = TEXT2DSL_IS_SHORTCUT
self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
sql_examplars, def update_examples(self, sql_examples, example_nums, is_shortcut):
self.schema_linking_example_selector, self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
self.sql_example_selector, sql_examples,
example_nums) self.sql_example_selector,
example_nums)
self.is_shortcut = is_shortcut
def query2sql(self, query_text: str, def query2sql(self, query_text: str,
schema : Union[dict, None] = None, schema : Union[dict, None] = None,
@@ -53,14 +58,14 @@ class Text2DSLAgent(object):
model_name = schema['modelName'] model_name = schema['modelName']
fields_list = schema['fieldNameList'] fields_list = schema['fieldNameList']
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector) schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector)
print("schema_linking_prompt->", schema_linking_prompt) print("schema_linking_prompt->", schema_linking_prompt)
schema_link_output = self.llm(schema_linking_prompt) schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output) schema_link_str = self.schema_link_parse(schema_link_output)
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector) sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
print("sql_prompt->", sql_prompt) print("sql_prompt->", sql_prompt)
sql_output = llm(sql_prompt) sql_output = self.llm(sql_prompt)
resp = dict() resp = dict()
resp['query'] = query_text resp['query'] = query_text
@@ -69,7 +74,7 @@ class Text2DSLAgent(object):
resp['priorSchemaLinking'] = linking resp['priorSchemaLinking'] = linking
resp['dataDate'] = current_date resp['dataDate'] = current_date
resp['schemaLinkingOutput'] = schema_link_output resp['analysisOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output resp['sqlOutput'] = sql_output
@@ -78,5 +83,57 @@ class Text2DSLAgent(object):
return resp return resp
def query2sqlcombo(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None
):
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
model_name = schema['modelName']
fields_list = schema['fieldNameList']
schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list,
prior_schema_links, self.sql_example_selector)
print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt)
schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt)
schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output)
sql_str = self.combo_sql_parse(schema_linking_sql_combo_output)
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = prior_schema_links
resp['dataDate'] = current_date
resp['analysisOutput'] = schema_linking_sql_combo_output
resp['schemaLinkStr'] = schema_linking_str
resp['sqlOutput'] = sql_str
print("resp: ", resp)
return resp
def query2sql_run(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None):
if self.is_shortcut:
return self.query2sqlcombo(query_text, schema, current_date, linking)
else:
return self.query2sql(query_text, schema, current_date, linking)
text2sql_agent = Text2DSLAgent() text2sql_agent = Text2DSLAgent()

View File

@@ -51,7 +51,7 @@ async def din_query2sql(query_body: Mapping[str, Any]):
else: else:
linking = query_body['linking'] linking = query_body['linking']
resp = text2sql_agent.query2sql(query_text=query_text, resp = text2sql_agent.query2sql_run(query_text=query_text,
schema=schema, current_date=current_date, linking=linking) schema=schema, current_date=current_date, linking=linking)
return resp return resp
@@ -70,7 +70,12 @@ async def query2sql_setting_update(query_body: Mapping[str, Any]):
else: else:
example_nums = query_body['exampleNums'] example_nums = query_body['exampleNums']
text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums) if 'isShortcut' not in query_body:
raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
else:
is_shortcut = query_body['isShortcut']
text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut)
return "success" return "success"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 123 KiB

After

Width:  |  Height:  |  Size: 107 KiB

View File

@@ -5,21 +5,25 @@ text2sql的功能实现高度依赖对LLM的应用。通过LLM生成SQL的过
### **配置方式** ### **配置方式**
1. 样本池的配置。 1. 样本池的配置。
- supersonic/chat/core/src/main/python/llm/few_shot_example/sql_exampler.py为样本池配置文件。用户可以以已有的样本作为参考配置更贴近自身业务需求的样本用于更好的引导LLM生成SQL。 - supersonic/chat/core/src/main/python/few_shot_example/sql_exampler.py 为样本池配置文件。用户可以以已有的样本作为参考配置更贴近自身业务需求的样本用于更好的引导LLM生成SQL。
2. 样本数量的配置。 2. 样本数量的配置。
- 在supersonic/chat/core/src/main/python/llm/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。 - supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。
- 默认值为15为项目在内部实践后较优的经验值。样本少太少对导致LLM在生成SQL的过程中缺少引导和示范生成的SQL会更不稳定样本太多会增加生成SQL需要的时间和LLM的token消耗或超过LLM的token上限 - 默认值为15为项目在内部实践后较优的经验值。样本少太少对导致LLM在生成SQL的过程中缺少引导和示范生成的SQL会更不稳定样本太多会增加生成SQL需要的时间和LLM的token消耗或超过LLM的token上限
- <div align="left" > 3. SQL生成方式的配置
- 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_IS_SHORTCUT 变量进行配置。
- 默认值为False当为False时会调用2次LLM生成SQL当为True时会只调用1次LLM生成SQL。相较于2次LLM调用生成的SQL耗时会减少30-40%token的消耗量会减少30%左右但生成的SQL正确率会有所下降。
<div align="left" >
<img src=../images/text2sql_config.png width="70%"/> <img src=../images/text2sql_config.png width="70%"/>
<p>图1-1 样本数量的配置文件</p> <p>图1-1 配置文件</p>
</div> </div>
3. 运行中更新配置的脚本 ### **运行中更新配置的脚本**
- 如果在启动项目后用户需要对text2sql功能的相关配置进行调试可以在修改相关配置文件后通过脚本 supersonic/chat/core/src/main/python/bin/text2sql_resetting.sh 在项目运行中让配置生效。 1. 如果在启动项目后用户需要对text2sql功能的相关配置进行调试可以在修改相关配置文件后通过以下2种方式让配置在项目运行中让配置生效。
- 执行 supersonic-daemon.sh reload llmparser
- 执行 python examples_reload_run.py
### **FAQ** ### **FAQ**
1. 生成一个SQL需要消耗的的LLM token数量太多了按照openAI对token的收费标准生成一个SQL太贵了可以少用一些token吗 1. 生成一个SQL需要消耗的的LLM token数量太多了按照openAI对token的收费标准生成一个SQL太贵了可以少用一些token吗
- 可以。 用户可以根据自身需求如配置方式1.中所示修改样本池中的样本选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。 - 可以。 用户可以根据自身需求如配置方式1.中所示修改样本池中的样本选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。配置方式3.中所示只调用1次LLM生成SQL。
- 需要注意样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。 - 需要注意样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。