Allow user to config examples and number of examples used by text2sql in middle of run (#85)

This commit is contained in:
codescracker
2023-09-13 17:24:12 +08:00
committed by GitHub
parent c38507d50c
commit 545fb139ee
13 changed files with 303 additions and 61 deletions

View File

@@ -0,0 +1,12 @@
#!/usr/bin/env bash
llm_host=$1
llm_port=$2
baseDir=$(cd "$binDir/.." && pwd -P)
cd $baseDir/llm/sql
${python_path} examples_reload_run.py ${llm_port} ${llm_host}

View File

@@ -10,7 +10,7 @@ from typing import Any, List, Mapping, Optional, Union
from fastapi import FastAPI, HTTPException
from sql.run import query2sql
from sql.run import text2sql_agent
from preset_retrieval.run import preset_query_retrieval_run, collection as preset_query_collection
from preset_retrieval.preset_query_db import (add2preset_query_collection, update_preset_query_collection,
@@ -46,12 +46,30 @@ async def din_query2sql(query_body: Mapping[str, Any]):
else:
linking = query_body['linking']
resp = query2sql(query_text=query_text,
resp = text2sql_agent.query2sql(query_text=query_text,
schema=schema, current_date=current_date, linking=linking)
return resp
@app.post("/query2sql_setting_update/")
async def query2sql_setting_update(query_body: Mapping[str, Any]):
if 'sqlExamplars' not in query_body:
raise HTTPException(status_code=400,
detail="sqlExamplars is not in query_body")
else:
sql_examplars = query_body['sqlExamplars']
if 'exampleNums' not in query_body:
raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
else:
example_nums = query_body['exampleNums']
text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums)
return "success"
@app.post("/preset_query_retrival/")
async def preset_query_retrival(query_text_list: List[str], n_results: int = 5):
parsed_retrieval_res_format = preset_query_retrieval_run(preset_query_collection, query_text_list, n_results)

View File

@@ -292,5 +292,57 @@ examplars= [
基于table和columns可能的cell values 是 = ['刘锝桦', 1992, 4, 2, 2020, 5, 2, 200000]。""",
"schema_links":"""["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""",
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
},
{
"current_date":"2023-09-04",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库近30天访问次数的平均数",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近30天访问次数的平均数“中我们被问
“访问次数的平均数”所以我们需要column=[访问次数]
”内容库近30天“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = [30]。""",
"schema_links":"""["访问次数", "数据日期", 30]""",
"sql":"""select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """
},
{
"current_date":"2023-09-04",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库近半年哪个月的访问次数汇总最高",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近半年哪个月的访问次数汇总最高“中,我们被问:
“访问次数汇总最高”所以我们需要column=[访问次数]
”内容库近半年“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = [0.5]。""",
"schema_links":"""["访问次数", "数据日期", 0.5]""",
"sql":"""select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1 """
},
{
"current_date":"2023-09-04",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库近半年每个月的平均访问次数",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库近半年每个月的平均访问次数“中,我们被问:
“每个月的平均访问次数”所以我们需要column=[访问次数]
”内容库近半年“所以我们需要column=[数据日期]
基于table和columns可能的cell values 是 = [0.5]。""",
"schema_links":"""["访问次数", "数据日期", 0.5]""",
"sql":"""select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) """
},
{
"current_date":"2023-09-10",
"table_name":"内容库产品",
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
"question":"内容库 按部门统计访问次数 top10 的部门",
"prior_schema_links":"""[]""",
"analysis": """让我们一步一步地思考。在问题“内容库 按部门统计访问次数 top10 的部门“中,我们被问:
“访问次数 top10 的部门”所以我们需要column=[访问次数]
”内容库 按部门统计“所以我们需要column=[部门]
基于table和columns可能的cell values 是 = [10]。""",
"schema_links":"""["访问次数", "部门", 10]""",
"sql":"""select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10 """
}
]

View File

@@ -5,11 +5,13 @@ PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
MODEL_NAME = "gpt-3.5-turbo-16k"
OPENAI_API_KEY = "YOUR_API_KEY"
TEMPERATURE = 0.0
CHROMA_DB_PERSIST_DIR = 'chm_db'
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

View File

@@ -1,6 +1,8 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union
import os
import sys
import time
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
@@ -13,41 +15,64 @@ from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from few_shot_example.sql_exampler import examplars as din_sql_examplars
from few_shot_example.sql_exampler import examplars as sql_examplars
from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
from util.chromadb_instance import client as chromadb_client
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
from run_config import TEXT2DSL_COLLECTION_NAME
def reload_sql_example_collection(vectorstore:Chroma,
sql_examplars:List[Mapping[str, str]],
schema_linking_example_selector:SemanticSimilarityExampleSelector,
sql_example_selector:SemanticSimilarityExampleSelector,
example_nums:int
):
print("original sql_examples_collection size:", vectorstore._collection.count())
new_collection = empty_chroma_collection_2(collection=vectorstore._collection)
vectorstore._collection = new_collection
print("emptied sql_examples_collection size:", vectorstore._collection.count())
vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
embedding_function=hg_embedding,
client=chromadb_client)
example_nums = 15
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=example_nums,
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=example_nums,
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
if vectorstore._collection.count() > 0:
print("examples already in din_sql_vectorstore")
print("init din_sql_vectorstore size:", vectorstore._collection.count())
if vectorstore._collection.count() < len(din_sql_examplars):
print("din_sql_examplars size:", len(din_sql_examplars))
vectorstore._collection.delete()
print("empty din_sql_vectorstore")
for example in din_sql_examplars:
schema_linking_example_selector.add_example(example)
print("added din_sql_vectorstore size:", vectorstore._collection.count())
else:
for example in din_sql_examplars:
for example in sql_examplars:
schema_linking_example_selector.add_example(example)
print("added din_sql_vectorstore size:", vectorstore._collection.count())
print("reloaded sql_examples_collection size:", vectorstore._collection.count())
return vectorstore, schema_linking_example_selector, sql_example_selector
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
embedding_function=hg_embedding,
client=chromadb_client)
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
input_keys=["question"],
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
if sql_examples_vectorstore._collection.count() > 0:
print("examples already in sql_vectorstore")
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
if sql_examples_vectorstore._collection.count() < len(sql_examplars):
print("sql_examplars size:", len(sql_examplars))
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
else:
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())

View File

@@ -0,0 +1,31 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union
import os
import sys
import requests
import json
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
from few_shot_example.sql_exampler import examplars as sql_examplars
def text2dsl_setting_update(llm_host:str, llm_port:str,
sql_examplars:List[Mapping[str, str]], example_nums:int):
url = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums}
headers = {'content-type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
print(response.text)
if __name__ == "__main__":
arguments = sys.argv
llm_host = str(arguments[1])
llm_port = str(arguments[2])
text2dsl_setting_update(llm_host, llm_port,
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM)

View File

@@ -8,53 +8,75 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from sql.prompt_maker import schema_linking_exampler, sql_exampler
from sql.constructor import schema_linking_example_selector, sql_example_selector
from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection
from sql.output_parser import schema_link_parse
from util.llm_instance import llm
def query2sql(query_text: str,
class Text2DSLAgent(object):
def __init__(self):
self.schema_linking_exampler = schema_linking_exampler
self.sql_exampler = sql_exampler
self.sql_examples_vectorstore = sql_examples_vectorstore
self.schema_linking_example_selector = schema_linking_example_selector
self.sql_example_selector = sql_example_selector
self.schema_link_parse = schema_link_parse
self.llm = llm
def update_examples(self, sql_examplars, example_nums):
self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
sql_examplars,
self.schema_linking_example_selector,
self.sql_example_selector,
example_nums)
def query2sql(self, query_text: str,
schema : Union[dict, None] = None,
current_date: str = None,
linking: Union[List[Mapping[str, str]], None] = None
):
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
print("query_text: ", query_text)
print("schema: ", schema)
print("current_date: ", current_date)
print("prior_schema_links: ", linking)
model_name = schema['modelName']
fields_list = schema['fieldNameList']
if linking is not None:
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
else:
prior_schema_links = {}
schema_linking_prompt = schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, schema_linking_example_selector)
print("schema_linking_prompt->", schema_linking_prompt)
schema_link_output = llm(schema_linking_prompt)
schema_link_str = schema_link_parse(schema_link_output)
sql_prompt = sql_exampler(query_text, model_name, schema_link_str, current_date, sql_example_selector)
print("sql_prompt->", sql_prompt)
sql_output = llm(sql_prompt)
model_name = schema['modelName']
fields_list = schema['fieldNameList']
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = linking
resp['dataDate'] = current_date
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector)
print("schema_linking_prompt->", schema_linking_prompt)
schema_link_output = self.llm(schema_linking_prompt)
schema_link_str = self.schema_link_parse(schema_link_output)
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
print("sql_prompt->", sql_prompt)
sql_output = llm(sql_prompt)
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
resp = dict()
resp['query'] = query_text
resp['model'] = model_name
resp['fields'] = fields_list
resp['priorSchemaLinking'] = linking
resp['dataDate'] = current_date
print("resp: ", resp)
resp['schemaLinkingOutput'] = schema_link_output
resp['schemaLinkStr'] = schema_link_str
resp['sqlOutput'] = sql_output
return resp
print("resp: ", resp)
return resp
text2sql_agent = Text2DSLAgent()

View File

@@ -1,5 +1,8 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union
import chromadb
from chromadb.api import Collection, Documents, Embeddings
from chromadb.config import Settings
from run_config import CHROMA_DB_PERSIST_PATH
@@ -7,4 +10,28 @@ from run_config import CHROMA_DB_PERSIST_PATH
client = chromadb.Client(Settings(
chroma_db_impl="duckdb+parquet",
persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory
))
))
def empty_chroma_collection_2(collection:Collection):
collection_name = collection.name
client = collection._client
metadata = collection.metadata
embedding_function = collection._embedding_function
client.delete_collection(collection_name)
new_collection = client.get_or_create_collection(name=collection_name,
metadata=metadata,
embedding_function=embedding_function)
size_of_new_collection = new_collection.count()
print(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}')
return new_collection
def empty_chroma_collection(collection:Collection):
collection.delete()

View File

@@ -4,4 +4,5 @@ fastapi==0.95.1
chromadb==0.3.21
tiktoken==0.3.3
uvicorn[standard]==0.21.1
pandas==1.5.3