mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Allow user to config examples and number of examples used by text2sql in middle of run (#85)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -15,4 +15,5 @@ assembly/runtime/*
|
|||||||
/assembly/deploy
|
/assembly/deploy
|
||||||
/runtime
|
/runtime
|
||||||
**/.flattened-pom.xml
|
**/.flattened-pom.xml
|
||||||
|
chm_db/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
12
chat/core/src/main/python/bin/text2sql_resetting.sh
Normal file
12
chat/core/src/main/python/bin/text2sql_resetting.sh
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
llm_host=$1
|
||||||
|
llm_port=$2
|
||||||
|
|
||||||
|
baseDir=$(cd "$binDir/.." && pwd -P)
|
||||||
|
|
||||||
|
cd $baseDir/llm/sql
|
||||||
|
|
||||||
|
${python_path} examples_reload_run.py ${llm_port} ${llm_host}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -10,7 +10,7 @@ from typing import Any, List, Mapping, Optional, Union
|
|||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
|
|
||||||
from sql.run import query2sql
|
from sql.run import text2sql_agent
|
||||||
|
|
||||||
from preset_retrieval.run import preset_query_retrieval_run, collection as preset_query_collection
|
from preset_retrieval.run import preset_query_retrieval_run, collection as preset_query_collection
|
||||||
from preset_retrieval.preset_query_db import (add2preset_query_collection, update_preset_query_collection,
|
from preset_retrieval.preset_query_db import (add2preset_query_collection, update_preset_query_collection,
|
||||||
@@ -46,12 +46,30 @@ async def din_query2sql(query_body: Mapping[str, Any]):
|
|||||||
else:
|
else:
|
||||||
linking = query_body['linking']
|
linking = query_body['linking']
|
||||||
|
|
||||||
resp = query2sql(query_text=query_text,
|
resp = text2sql_agent.query2sql(query_text=query_text,
|
||||||
schema=schema, current_date=current_date, linking=linking)
|
schema=schema, current_date=current_date, linking=linking)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/query2sql_setting_update/")
|
||||||
|
async def query2sql_setting_update(query_body: Mapping[str, Any]):
|
||||||
|
if 'sqlExamplars' not in query_body:
|
||||||
|
raise HTTPException(status_code=400,
|
||||||
|
detail="sqlExamplars is not in query_body")
|
||||||
|
else:
|
||||||
|
sql_examplars = query_body['sqlExamplars']
|
||||||
|
|
||||||
|
if 'exampleNums' not in query_body:
|
||||||
|
raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
|
||||||
|
else:
|
||||||
|
example_nums = query_body['exampleNums']
|
||||||
|
|
||||||
|
text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums)
|
||||||
|
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
|
||||||
@app.post("/preset_query_retrival/")
|
@app.post("/preset_query_retrival/")
|
||||||
async def preset_query_retrival(query_text_list: List[str], n_results: int = 5):
|
async def preset_query_retrival(query_text_list: List[str], n_results: int = 5):
|
||||||
parsed_retrieval_res_format = preset_query_retrieval_run(preset_query_collection, query_text_list, n_results)
|
parsed_retrieval_res_format = preset_query_retrieval_run(preset_query_collection, query_text_list, n_results)
|
||||||
|
|||||||
@@ -292,5 +292,57 @@ examplars= [
|
|||||||
基于table和columns,可能的cell values 是 = ['刘锝桦', 1992, 4, 2, 2020, 5, 2, 200000]。""",
|
基于table和columns,可能的cell values 是 = ['刘锝桦', 1992, 4, 2, 2020, 5, 2, 200000]。""",
|
||||||
"schema_links":"""["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""",
|
"schema_links":"""["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""",
|
||||||
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
|
"sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'"""
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"current_date":"2023-09-04",
|
||||||
|
"table_name":"内容库产品",
|
||||||
|
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
|
||||||
|
"question":"内容库近30天访问次数的平均数",
|
||||||
|
"prior_schema_links":"""[]""",
|
||||||
|
"analysis": """让我们一步一步地思考。在问题“内容库近30天访问次数的平均数“中,我们被问:
|
||||||
|
“访问次数的平均数”,所以我们需要column=[访问次数]
|
||||||
|
”内容库近30天“,所以我们需要column=[数据日期]
|
||||||
|
基于table和columns,可能的cell values 是 = [30]。""",
|
||||||
|
"schema_links":"""["访问次数", "数据日期", 30]""",
|
||||||
|
"sql":"""select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"current_date":"2023-09-04",
|
||||||
|
"table_name":"内容库产品",
|
||||||
|
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
|
||||||
|
"question":"内容库近半年哪个月的访问次数汇总最高",
|
||||||
|
"prior_schema_links":"""[]""",
|
||||||
|
"analysis": """让我们一步一步地思考。在问题“内容库近半年哪个月的访问次数汇总最高“中,我们被问:
|
||||||
|
“访问次数汇总最高”,所以我们需要column=[访问次数]
|
||||||
|
”内容库近半年“,所以我们需要column=[数据日期]
|
||||||
|
基于table和columns,可能的cell values 是 = [0.5]。""",
|
||||||
|
"schema_links":"""["访问次数", "数据日期", 0.5]""",
|
||||||
|
"sql":"""select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1 """
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"current_date":"2023-09-04",
|
||||||
|
"table_name":"内容库产品",
|
||||||
|
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
|
||||||
|
"question":"内容库近半年每个月的平均访问次数",
|
||||||
|
"prior_schema_links":"""[]""",
|
||||||
|
"analysis": """让我们一步一步地思考。在问题“内容库近半年每个月的平均访问次数“中,我们被问:
|
||||||
|
“每个月的平均访问次数”,所以我们需要column=[访问次数]
|
||||||
|
”内容库近半年“,所以我们需要column=[数据日期]
|
||||||
|
基于table和columns,可能的cell values 是 = [0.5]。""",
|
||||||
|
"schema_links":"""["访问次数", "数据日期", 0.5]""",
|
||||||
|
"sql":"""select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) """
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"current_date":"2023-09-10",
|
||||||
|
"table_name":"内容库产品",
|
||||||
|
"fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
|
||||||
|
"question":"内容库 按部门统计访问次数 top10 的部门",
|
||||||
|
"prior_schema_links":"""[]""",
|
||||||
|
"analysis": """让我们一步一步地思考。在问题“内容库 按部门统计访问次数 top10 的部门“中,我们被问:
|
||||||
|
“访问次数 top10 的部门”,所以我们需要column=[访问次数]
|
||||||
|
”内容库 按部门统计“,所以我们需要column=[部门]
|
||||||
|
基于table和columns,可能的cell values 是 = [10]。""",
|
||||||
|
"schema_links":"""["访问次数", "部门", 10]""",
|
||||||
|
"sql":"""select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10 """
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -5,11 +5,13 @@ PROJECT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
|
|||||||
|
|
||||||
MODEL_NAME = "gpt-3.5-turbo-16k"
|
MODEL_NAME = "gpt-3.5-turbo-16k"
|
||||||
OPENAI_API_KEY = "YOUR_API_KEY"
|
OPENAI_API_KEY = "YOUR_API_KEY"
|
||||||
|
|
||||||
TEMPERATURE = 0.0
|
TEMPERATURE = 0.0
|
||||||
|
|
||||||
CHROMA_DB_PERSIST_DIR = 'chm_db'
|
CHROMA_DB_PERSIST_DIR = 'chm_db'
|
||||||
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
|
PRESET_QUERY_COLLECTION_NAME = "preset_query_collection"
|
||||||
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
|
TEXT2DSL_COLLECTION_NAME = "text2dsl_collection"
|
||||||
|
TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15
|
||||||
|
|
||||||
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
|
from typing import Any, List, Mapping, Optional, Union
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
@@ -13,41 +15,64 @@ from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
|
|||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
from chromadb.api import Collection, Documents, Embeddings
|
||||||
|
|
||||||
from few_shot_example.sql_exampler import examplars as din_sql_examplars
|
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||||
from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
|
from util.text2vec import Text2VecEmbeddingFunction, hg_embedding
|
||||||
from util.chromadb_instance import client as chromadb_client
|
from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2
|
||||||
|
from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||||
|
|
||||||
|
|
||||||
from run_config import TEXT2DSL_COLLECTION_NAME
|
def reload_sql_example_collection(vectorstore:Chroma,
|
||||||
|
sql_examplars:List[Mapping[str, str]],
|
||||||
|
schema_linking_example_selector:SemanticSimilarityExampleSelector,
|
||||||
|
sql_example_selector:SemanticSimilarityExampleSelector,
|
||||||
|
example_nums:int
|
||||||
|
):
|
||||||
|
print("original sql_examples_collection size:", vectorstore._collection.count())
|
||||||
|
new_collection = empty_chroma_collection_2(collection=vectorstore._collection)
|
||||||
|
vectorstore._collection = new_collection
|
||||||
|
|
||||||
|
print("emptied sql_examples_collection size:", vectorstore._collection.count())
|
||||||
|
|
||||||
vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
|
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
||||||
embedding_function=hg_embedding,
|
|
||||||
client=chromadb_client)
|
|
||||||
|
|
||||||
example_nums = 15
|
|
||||||
|
|
||||||
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=example_nums,
|
|
||||||
input_keys=["question"],
|
input_keys=["question"],
|
||||||
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
|
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
|
||||||
|
|
||||||
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=vectorstore, k=example_nums,
|
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
||||||
input_keys=["question"],
|
input_keys=["question"],
|
||||||
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
|
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
|
||||||
|
|
||||||
if vectorstore._collection.count() > 0:
|
for example in sql_examplars:
|
||||||
print("examples already in din_sql_vectorstore")
|
|
||||||
print("init din_sql_vectorstore size:", vectorstore._collection.count())
|
|
||||||
if vectorstore._collection.count() < len(din_sql_examplars):
|
|
||||||
print("din_sql_examplars size:", len(din_sql_examplars))
|
|
||||||
vectorstore._collection.delete()
|
|
||||||
print("empty din_sql_vectorstore")
|
|
||||||
for example in din_sql_examplars:
|
|
||||||
schema_linking_example_selector.add_example(example)
|
|
||||||
print("added din_sql_vectorstore size:", vectorstore._collection.count())
|
|
||||||
else:
|
|
||||||
for example in din_sql_examplars:
|
|
||||||
schema_linking_example_selector.add_example(example)
|
schema_linking_example_selector.add_example(example)
|
||||||
|
|
||||||
print("added din_sql_vectorstore size:", vectorstore._collection.count())
|
print("reloaded sql_examples_collection size:", vectorstore._collection.count())
|
||||||
|
|
||||||
|
return vectorstore, schema_linking_example_selector, sql_example_selector
|
||||||
|
|
||||||
|
|
||||||
|
sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME,
|
||||||
|
embedding_function=hg_embedding,
|
||||||
|
client=chromadb_client)
|
||||||
|
|
||||||
|
example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||||
|
|
||||||
|
schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
||||||
|
input_keys=["question"],
|
||||||
|
example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"])
|
||||||
|
|
||||||
|
sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums,
|
||||||
|
input_keys=["question"],
|
||||||
|
example_keys=["question", "current_date", "table_name", "schema_links", "sql"])
|
||||||
|
|
||||||
|
if sql_examples_vectorstore._collection.count() > 0:
|
||||||
|
print("examples already in sql_vectorstore")
|
||||||
|
print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
||||||
|
if sql_examples_vectorstore._collection.count() < len(sql_examplars):
|
||||||
|
print("sql_examplars size:", len(sql_examplars))
|
||||||
|
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
|
||||||
|
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
||||||
|
else:
|
||||||
|
sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums)
|
||||||
|
print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count())
|
||||||
|
|
||||||
|
|||||||
31
chat/core/src/main/python/llm/sql/examples_reload_run.py
Normal file
31
chat/core/src/main/python/llm/sql/examples_reload_run.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# -*- coding:utf-8 -*-
|
||||||
|
from typing import Any, List, Mapping, Optional, Union
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM
|
||||||
|
from few_shot_example.sql_exampler import examplars as sql_examplars
|
||||||
|
|
||||||
|
def text2dsl_setting_update(llm_host:str, llm_port:str,
|
||||||
|
sql_examplars:List[Mapping[str, str]], example_nums:int):
|
||||||
|
|
||||||
|
url = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
|
||||||
|
payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums}
|
||||||
|
headers = {'content-type': 'application/json'}
|
||||||
|
response = requests.post(url, data=json.dumps(payload), headers=headers)
|
||||||
|
print(response.text)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
arguments = sys.argv
|
||||||
|
|
||||||
|
llm_host = str(arguments[1])
|
||||||
|
llm_port = str(arguments[2])
|
||||||
|
|
||||||
|
text2dsl_setting_update(llm_host, llm_port,
|
||||||
|
sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM)
|
||||||
@@ -8,53 +8,75 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from sql.prompt_maker import schema_linking_exampler, sql_exampler
|
from sql.prompt_maker import schema_linking_exampler, sql_exampler
|
||||||
from sql.constructor import schema_linking_example_selector, sql_example_selector
|
from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection
|
||||||
from sql.output_parser import schema_link_parse
|
from sql.output_parser import schema_link_parse
|
||||||
|
|
||||||
from util.llm_instance import llm
|
from util.llm_instance import llm
|
||||||
|
|
||||||
|
|
||||||
def query2sql(query_text: str,
|
class Text2DSLAgent(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.schema_linking_exampler = schema_linking_exampler
|
||||||
|
self.sql_exampler = sql_exampler
|
||||||
|
|
||||||
|
self.sql_examples_vectorstore = sql_examples_vectorstore
|
||||||
|
self.schema_linking_example_selector = schema_linking_example_selector
|
||||||
|
self.sql_example_selector = sql_example_selector
|
||||||
|
|
||||||
|
self.schema_link_parse = schema_link_parse
|
||||||
|
|
||||||
|
self.llm = llm
|
||||||
|
|
||||||
|
def update_examples(self, sql_examplars, example_nums):
|
||||||
|
self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore,
|
||||||
|
sql_examplars,
|
||||||
|
self.schema_linking_example_selector,
|
||||||
|
self.sql_example_selector,
|
||||||
|
example_nums)
|
||||||
|
|
||||||
|
def query2sql(self, query_text: str,
|
||||||
schema : Union[dict, None] = None,
|
schema : Union[dict, None] = None,
|
||||||
current_date: str = None,
|
current_date: str = None,
|
||||||
linking: Union[List[Mapping[str, str]], None] = None
|
linking: Union[List[Mapping[str, str]], None] = None
|
||||||
):
|
):
|
||||||
|
|
||||||
print("query_text: ", query_text)
|
|
||||||
print("schema: ", schema)
|
|
||||||
print("current_date: ", current_date)
|
|
||||||
print("prior_schema_links: ", linking)
|
|
||||||
|
|
||||||
if linking is not None:
|
print("query_text: ", query_text)
|
||||||
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
print("schema: ", schema)
|
||||||
else:
|
print("current_date: ", current_date)
|
||||||
prior_schema_links = {}
|
print("prior_schema_links: ", linking)
|
||||||
|
|
||||||
model_name = schema['modelName']
|
if linking is not None:
|
||||||
fields_list = schema['fieldNameList']
|
prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking}
|
||||||
|
else:
|
||||||
|
prior_schema_links = {}
|
||||||
|
|
||||||
schema_linking_prompt = schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, schema_linking_example_selector)
|
model_name = schema['modelName']
|
||||||
print("schema_linking_prompt->", schema_linking_prompt)
|
fields_list = schema['fieldNameList']
|
||||||
schema_link_output = llm(schema_linking_prompt)
|
|
||||||
schema_link_str = schema_link_parse(schema_link_output)
|
|
||||||
|
|
||||||
sql_prompt = sql_exampler(query_text, model_name, schema_link_str, current_date, sql_example_selector)
|
|
||||||
print("sql_prompt->", sql_prompt)
|
|
||||||
sql_output = llm(sql_prompt)
|
|
||||||
|
|
||||||
resp = dict()
|
schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector)
|
||||||
resp['query'] = query_text
|
print("schema_linking_prompt->", schema_linking_prompt)
|
||||||
resp['model'] = model_name
|
schema_link_output = self.llm(schema_linking_prompt)
|
||||||
resp['fields'] = fields_list
|
schema_link_str = self.schema_link_parse(schema_link_output)
|
||||||
resp['priorSchemaLinking'] = linking
|
|
||||||
resp['dataDate'] = current_date
|
sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector)
|
||||||
|
print("sql_prompt->", sql_prompt)
|
||||||
|
sql_output = llm(sql_prompt)
|
||||||
|
|
||||||
resp['schemaLinkingOutput'] = schema_link_output
|
resp = dict()
|
||||||
resp['schemaLinkStr'] = schema_link_str
|
resp['query'] = query_text
|
||||||
|
resp['model'] = model_name
|
||||||
resp['sqlOutput'] = sql_output
|
resp['fields'] = fields_list
|
||||||
|
resp['priorSchemaLinking'] = linking
|
||||||
|
resp['dataDate'] = current_date
|
||||||
|
|
||||||
print("resp: ", resp)
|
resp['schemaLinkingOutput'] = schema_link_output
|
||||||
|
resp['schemaLinkStr'] = schema_link_str
|
||||||
|
|
||||||
|
resp['sqlOutput'] = sql_output
|
||||||
|
|
||||||
return resp
|
print("resp: ", resp)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
text2sql_agent = Text2DSLAgent()
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
# -*- coding:utf-8 -*-
|
# -*- coding:utf-8 -*-
|
||||||
|
from typing import Any, List, Mapping, Optional, Union
|
||||||
|
|
||||||
import chromadb
|
import chromadb
|
||||||
|
from chromadb.api import Collection, Documents, Embeddings
|
||||||
from chromadb.config import Settings
|
from chromadb.config import Settings
|
||||||
|
|
||||||
from run_config import CHROMA_DB_PERSIST_PATH
|
from run_config import CHROMA_DB_PERSIST_PATH
|
||||||
@@ -7,4 +10,28 @@ from run_config import CHROMA_DB_PERSIST_PATH
|
|||||||
client = chromadb.Client(Settings(
|
client = chromadb.Client(Settings(
|
||||||
chroma_db_impl="duckdb+parquet",
|
chroma_db_impl="duckdb+parquet",
|
||||||
persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory
|
persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def empty_chroma_collection_2(collection:Collection):
|
||||||
|
collection_name = collection.name
|
||||||
|
client = collection._client
|
||||||
|
metadata = collection.metadata
|
||||||
|
embedding_function = collection._embedding_function
|
||||||
|
|
||||||
|
client.delete_collection(collection_name)
|
||||||
|
|
||||||
|
new_collection = client.get_or_create_collection(name=collection_name,
|
||||||
|
metadata=metadata,
|
||||||
|
embedding_function=embedding_function)
|
||||||
|
|
||||||
|
size_of_new_collection = new_collection.count()
|
||||||
|
|
||||||
|
print(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}')
|
||||||
|
|
||||||
|
return new_collection
|
||||||
|
|
||||||
|
|
||||||
|
def empty_chroma_collection(collection:Collection):
|
||||||
|
collection.delete()
|
||||||
|
|
||||||
|
|||||||
@@ -4,4 +4,5 @@ fastapi==0.95.1
|
|||||||
chromadb==0.3.21
|
chromadb==0.3.21
|
||||||
tiktoken==0.3.3
|
tiktoken==0.3.3
|
||||||
uvicorn[standard]==0.21.1
|
uvicorn[standard]==0.21.1
|
||||||
pandas==1.5.3
|
|
||||||
|
|
||||||
|
|||||||
BIN
docs/images/text2sql_config.png
Normal file
BIN
docs/images/text2sql_config.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 123 KiB |
26
docs/userguides/llm_config_cn.md
Normal file
26
docs/userguides/llm_config_cn.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# LLM模型配置
|
||||||
|
|
||||||
|
### **简介**
|
||||||
|
|
||||||
|
语言模型的使用是超音数的重要一环。能显著增强对用户的问题的理解能力,是通过对话形式与用户交互的基石之一。在本项目中对语言模型能力的应用主要在 LLM 和 Embedding 两方面;默认使用的模型中,LLM选用闭源模型 gpt-3.5-turbo-16k,Embedding模型选用开源模型 GanymedeNil/text2vec-large-chinese。用户可以根据自己实际需求进行配置更改。
|
||||||
|
|
||||||
|
|
||||||
|
### **配置方式**
|
||||||
|
<div align="left" >
|
||||||
|
<img src=../images/nlp_config.png width="70%"/>
|
||||||
|
<p>图1-1 LLM配置文件</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
1. LLM模型相关的配置,在 supersonic/chat/core/src/main/python/llm/run_config.py 进行配置。
|
||||||
|
2. LLM采用OpenAI的闭源模型 gpt-3.5-turbo-16k,在使用时需要提供OpenAI的API-Key才能调用LLM模型,通过 OPENAI_API_KEY 变量进行配置。
|
||||||
|
3. Embedding模型采用开源模型 GanymedeNil/text2vec-large-chinese,通过 HF_TEXT2VEC_MODEL_NAME 变量进行位置,为了使用方便采用托管在HuggingFace的源,初次启动时自动下载模型文件。
|
||||||
|
|
||||||
|
### **FAQ**
|
||||||
|
1. 可以用开源的LLM模型替代OpenAI的GPT模型吗?
|
||||||
|
- 暂时不能。我们测试过大部分主流的开源LLM,在实际使用中,在本项目需要LLM提供的逻辑推理和代码生成场景上,开源模型还不能满足需求。
|
||||||
|
- 我们会持续跟进开源LLM的最新进展,在有满足要求的开源LLM后,在项目中集成私有化部署开源LLM的能力。
|
||||||
|
2. GPT4、GPT3.5、GPT3.5-16k 这几个模型用哪个比较好?
|
||||||
|
- GPT3.5、GPT3.5-16k 均能基本满足要求,但会有输出结果不稳定的情况;GPT3.5的token长度限制为4k,在现有CoT策略下,容易出现超过长度限制的情况。
|
||||||
|
- GPT4的输出更稳定,但费用成本远超GPT3.5,可以根据实际使用场景进行选择。
|
||||||
|
3. Embedding模型用其他的可以吗?
|
||||||
|
- 可以。可以以该项目[text2vec]([URL](https://github.com/shibing624/text2vec))的榜单作为参考,然后在HuggingFace找到对应模型的model card,修改HF_TEXT2VEC_MODEL_NAME变量的取值。
|
||||||
25
docs/userguides/text2sql_cn.md
Normal file
25
docs/userguides/text2sql_cn.md
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# text2sql功能相关配置
|
||||||
|
|
||||||
|
### **简介**
|
||||||
|
text2sql的功能实现,高度依赖对LLM的应用。通过LLM生成SQL的过程中,利用小样本(few-shots-examples)通过思维链(chain-of-thoughts)的方式对LLM in-context-learning的能力进行引导,对于生成较为稳定且符合下游语法解析规则的SQL非常重要。用户可以根据自身需要,对样本池及样本的数量进行配置,使其更加符合自身业务特点。
|
||||||
|
|
||||||
|
### **配置方式**
|
||||||
|
1. 样本池的配置。
|
||||||
|
- supersonic/chat/core/src/main/python/llm/few_shot_example/sql_exampler.py为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。
|
||||||
|
2. 样本数量的配置。
|
||||||
|
- 在supersonic/chat/core/src/main/python/llm/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。
|
||||||
|
- 默认值为15,为项目在内部实践后较优的经验值。样本少太少,对导致LLM在生成SQL的过程中缺少引导和示范,生成的SQL会更不稳定;样本太多,会增加生成SQL需要的时间和LLM的token消耗(或超过LLM的token上限)。
|
||||||
|
- <div align="left" >
|
||||||
|
<img src=../images/text2sql_config.png width="70%"/>
|
||||||
|
<p>图1-1 样本数量的配置文件</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
3. 运行中更新配置的脚本。
|
||||||
|
- 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过脚本 supersonic/chat/core/src/main/python/bin/text2sql_resetting.sh 在项目运行中让配置生效。
|
||||||
|
|
||||||
|
### **FAQ**
|
||||||
|
1. 生成一个SQL需要消耗的的LLM token数量太多了,按照openAI对token的收费标准,生成一个SQL太贵了,可以少用一些token吗?
|
||||||
|
- 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。
|
||||||
|
- 需要注意,样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user