Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-12 12:37:55 +00:00)
[improvement][chat] Move python code out of chat-core module
chat/python/config/config_parse.py · Normal file · 70 lines added
@@ -0,0 +1,70 @@
# -*- coding:utf-8 -*-
import os
import sys
import configparser

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger


def type_convert(input_str: str):
    # Interpret the raw INI string as a Python literal (int, float, bool, ...);
    # fall back to the original string if it does not evaluate.
    try:
        return eval(input_str)
    except Exception:
        return input_str


PROJECT_DIR_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
config_dir = "config"
CONFIG_DIR_PATH = os.path.join(PROJECT_DIR_PATH, config_dir)
config_file = "run_config.ini"
config_path = os.path.join(CONFIG_DIR_PATH, config_file)

config = configparser.ConfigParser()
config.read(config_path)

llm_parser_section_name = "LLMParser"
LLMPARSER_HOST = config.get(llm_parser_section_name, 'LLMPARSER_HOST')
LLMPARSER_PORT = int(config.get(llm_parser_section_name, 'LLMPARSER_PORT'))

chroma_db_section_name = "ChromaDB"
CHROMA_DB_PERSIST_DIR = config.get(chroma_db_section_name, 'CHROMA_DB_PERSIST_DIR')
PRESET_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'PRESET_QUERY_COLLECTION_NAME')
SOLVED_QUERY_COLLECTION_NAME = config.get(chroma_db_section_name, 'SOLVED_QUERY_COLLECTION_NAME')
TEXT2DSLAGENT_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENT_COLLECTION_NAME')
TEXT2DSLAGENTCS_COLLECTION_NAME = config.get(chroma_db_section_name, 'TEXT2DSLAGENTCS_COLLECTION_NAME')
TEXT2DSL_EXAMPLE_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_EXAMPLE_NUM'))
TEXT2DSL_FEWSHOTS_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_FEWSHOTS_NUM'))
TEXT2DSL_SELF_CONSISTENCY_NUM = int(config.get(chroma_db_section_name, 'TEXT2DSL_SELF_CONSISTENCY_NUM'))
TEXT2DSL_IS_SHORTCUT = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SHORTCUT'))
TEXT2DSL_IS_SELF_CONSISTENCY = eval(config.get(chroma_db_section_name, 'TEXT2DSL_IS_SELF_CONSISTENCY'))
CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR)

text2vec_section_name = "Text2Vec"
HF_TEXT2VEC_MODEL_NAME = config.get(text2vec_section_name, 'HF_TEXT2VEC_MODEL_NAME')

llm_provider_section_name = "LLMProvider"
LLM_PROVIDER_NAME = config.get(llm_provider_section_name, 'LLM_PROVIDER_NAME')

llm_model_section_name = "LLMModel"
llm_config_dict = {}
for option in config.options(llm_model_section_name):
    llm_config_dict[option] = type_convert(config.get(llm_model_section_name, option))


if __name__ == "__main__":
    logger.info(f"PROJECT_DIR_PATH: {PROJECT_DIR_PATH}")
    logger.info(f"EMB_MODEL_PATH: {HF_TEXT2VEC_MODEL_NAME}")
    logger.info(f"CHROMA_DB_PERSIST_PATH: {CHROMA_DB_PERSIST_PATH}")
    logger.info(f"LLMPARSER_HOST: {LLMPARSER_HOST}")
    logger.info(f"LLMPARSER_PORT: {LLMPARSER_PORT}")
    logger.info(f"llm_config_dict: {llm_config_dict}")
    logger.info(f"TEXT2DSL_EXAMPLE_NUM: {TEXT2DSL_EXAMPLE_NUM}")
    logger.info(f"TEXT2DSL_FEWSHOTS_NUM: {TEXT2DSL_FEWSHOTS_NUM}")
    logger.info(f"TEXT2DSL_SELF_CONSISTENCY_NUM: {TEXT2DSL_SELF_CONSISTENCY_NUM}")
    logger.info(f"TEXT2DSL_IS_SHORTCUT: {TEXT2DSL_IS_SHORTCUT}")
    logger.info(f"TEXT2DSL_IS_SELF_CONSISTENCY: {TEXT2DSL_IS_SELF_CONSISTENCY}")
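
A quick illustration (not part of the commit) of what type_convert above returns for values like those in run_config.ini below: eval coerces numeric and boolean strings into Python values, and anything that is not a valid expression (a model name, an API key) falls back to the raw string.

    # Illustrative sketch only, not repository code.
    from config.config_parse import type_convert
    assert type_convert("9092") == 9092                # int
    assert type_convert("0.0") == 0.0                  # float
    assert type_convert("False") is False              # bool
    assert type_convert("gpt-3.5-turbo-16k") == "gpt-3.5-turbo-16k"  # NameError -> raw string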
chat/python/config/run_config.ini · Normal file · 28 lines added
@@ -0,0 +1,28 @@
[LLMParser]
LLMPARSER_HOST = 127.0.0.1
LLMPARSER_PORT = 9092

[ChromaDB]
CHROMA_DB_PERSIST_DIR = chm_db
PRESET_QUERY_COLLECTION_NAME = preset_query_collection
SOLVED_QUERY_COLLECTION_NAME = solved_query_collection
TEXT2DSLAGENT_COLLECTION_NAME = text2dsl_agent_collection
TEXT2DSLAGENTCS_COLLECTION_NAME = text2dsl_agent_cs_collection
TEXT2DSL_EXAMPLE_NUM = 15
TEXT2DSL_FEWSHOTS_NUM = 10
TEXT2DSL_SELF_CONSISTENCY_NUM = 5
TEXT2DSL_IS_SHORTCUT = False
TEXT2DSL_IS_SELF_CONSISTENCY = False

[Text2Vec]
HF_TEXT2VEC_MODEL_NAME = GanymedeNil/text2vec-large-chinese

[LLMProvider]
LLM_PROVIDER_NAME = openai

[LLMModel]
MODEL_NAME = gpt-3.5-turbo-16k
OPENAI_API_KEY = YOUR_API_KEY
TEMPERATURE = 0.0
chat/python/few_shot_example/sql_examplar.py · Normal file · 374 lines added
@@ -0,0 +1,374 @@
examplars = [
    { "currentDate":"2020-12-01",
      "tableName":"内容库产品",
      "fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
      "question":"比较jackjchen和robinlee在内容库的访问次数",
      "priorSchemaLinks":"""['jackjchen'->用户名, 'robinlee'->用户名]""",
      "analysis": """让我们一步一步地思考。在问题“比较jackjchen和robinlee在内容库的访问次数“中,我们被问:
“比较jackjchen和robinlee”,所以我们需要column=[用户名],cell values = ['jackjchen', 'robinlee'],所以有[用户名:('jackjchen', 'robinlee')]
”内容库的访问次数“,所以我们需要column=[访问次数]""",
      "schemaLinks":"""["用户名":("'jackjchen'", "'robinlee'"), "访问次数"]""",
      "sql":"""select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jackjchen', 'robinlee')"""
    },
    { "currentDate":"2022-11-06",
      "tableName":"内容库产品",
      "fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
      "question":"内容库近12个月访问人数 按部门",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库近12个月访问人数 按部门“中,我们被问:
”内容库近12个月“,所以我们需要column=[数据日期],cell values = [12],所以有[数据日期:(12)]
“访问人数”,所以我们需要column=[访问人数]
”按部门“,所以我们需要column=[部门]""",
      "schemaLinks":"""["数据日期":(12), "访问人数", "部门"]""",
      "sql":"""select 部门, 数据日期, 访问人数 from 内容库产品 where datediff('month', 数据日期, '2022-11-06') <= 12 """
    },
    { "currentDate":"2023-04-21",
      "tableName":"内容库产品",
      "fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
      "question":"内容库美术部、技术研发部的访问时长",
      "priorSchemaLinks":"""['美术部'->部门, '技术研发部'->部门]""",
      "analysis": """让我们一步一步地思考。在问题“内容库美术部、技术研发部的访问时长“中,我们被问:
“访问时长”,所以我们需要column=[访问时长]
”内容库美术部、技术研发部“,所以我们需要column=[部门], cell values = ['美术部', '技术研发部'],所以有[部门:('美术部', '技术研发部')]""",
      "schemaLinks":"""["访问时长", "部门":("'美术部'", "'技术研发部'")]""",
      "sql":"""select 部门, 访问时长 from 内容库产品 where 部门 in ('美术部', '技术研发部')"""
    },
    { "currentDate":"2023-08-21",
      "tableName":"严选",
      "fieldsList":"""["严选版权归属系", "付费模式", "结算播放份额", "付费用户结算播放份额", "数据日期"]""",
      "question":"近3天海田飞系MPPM结算播放份额",
      "priorSchemaLinks":"""['海田飞系'->严选版权归属系]""",
      "analysis": """让我们一步一步地思考。在问题“近3天海田飞系MPPM结算播放份额“中,我们被问:
“MPPM结算播放份额”,所以我们需要column=[结算播放份额],
”海田飞系“,所以我们需要column=[严选版权归属系], cell values = ['海田飞系'],所以有[严选版权归属系:('海田飞系')],
”近3天“,所以我们需要column=[数据日期], cell values = [3],所以有[数据日期:(3)]""",
      "schemaLinks":"""["结算播放份额", "严选版权归属系":("'海田飞系'"), "数据日期":(3)]""",
      "sql":"""select 严选版权归属系, 结算播放份额 from 严选 where 严选版权归属系 = '海田飞系' and datediff('day', 数据日期, '2023-08-21') <= 3 """
    },
    { "currentDate":"2023-05-22",
      "tableName":"歌曲库",
      "fieldsList":"""["是否潮流人歌曲", "C音歌曲ID", "C音歌曲MID", "歌曲名", "歌曲版本", "语种", "歌曲类型", "翻唱类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "结算播放量", "运营播放量", "付费用户结算播放量", "历史累计结算播放量", "运营搜播量", "结算搜播量", "运营完播量", "运营推播量", "近7日复播率", "日均搜播量", "数据日期"]""",
      "question":"对比近7天翻唱版和纯音乐的歌曲播放量",
      "priorSchemaLinks":"""['纯音乐'->语种, '翻唱版'->歌曲版本]""",
      "analysis": """让我们一步一步地思考。在问题“对比近7天翻唱版和纯音乐的歌曲播放量“中,我们被问:
“歌曲播放量”,所以我们需要column=[结算播放量]
”翻唱版“,所以我们需要column=[歌曲版本], cell values = ['翻唱版'],所以有[歌曲版本:('翻唱版')]
”和纯音乐的歌曲“,所以我们需要column=[语种], cell values = ['纯音乐'],所以有[语种:('纯音乐')]
”近7天“,所以我们需要column=[数据日期], cell values = [7],所以有[数据日期:(7)]""",
      "schemaLinks":"""["结算播放量", "歌曲版本":("'翻唱版'"), "语种":("'纯音乐'"), "数据日期":(7)]""",
      "sql":"""select 歌曲版本, 语种, 结算播放量 from 歌曲库 where 歌曲版本 = '翻唱版' and 语种 = '纯音乐' and datediff('day', 数据日期, '2023-05-22') <= 7 """
    },
    { "currentDate":"2023-05-31",
      "tableName":"艺人库",
      "fieldsList":"""["上下架状态", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "活跃区域", "年龄", "歌手才能", "歌手风格", "粉丝数", "潮音粉丝数", "超声波粉丝数", "推博粉丝数", "超声波歌曲数", "在架歌曲数", "超声波分享数", "独占歌曲数", "超声波在架歌曲评论数", "有播放量歌曲数", "数据日期"]""",
      "question":"对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数",
      "priorSchemaLinks":"""['1527896'->MPPM歌手ID, '1565463'->MPPM歌手ID, '2141459'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数“中,我们被问:
“粉丝数”,所以我们需要column=[粉丝数]
”陈拙悬、孟梅琦、赖媚韵“,所以我们需要column=[歌手名], cell values = ['陈拙悬', '孟梅琦', '赖媚韵'],所以有[歌手名:('陈拙悬', '孟梅琦', '赖媚韵')]""",
      "schemaLinks":"""["粉丝数", "歌手名":("'陈拙悬'", "'孟梅琦'", "'赖媚韵'")]""",
      "sql":"""select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('陈拙悬', '孟梅琦', '赖媚韵')"""
    },
    { "currentDate":"2023-07-31",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"播放量大于1万的歌曲有多少",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“播放量大于1万的歌曲有多少“中,我们被问:
“歌曲有多少”,所以我们需要column=[歌曲名]
”播放量大于1万的“,所以我们需要column=[结算播放量], cell values = [10000],所以有[结算播放量:(10000)]""",
      "schemaLinks":"""["歌曲名", "结算播放量":(10000)]""",
      "sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 10000"""
    },
    { "currentDate":"2023-07-31",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"内容库访问时长小于1小时,且来自美术部的用户是哪些",
      "priorSchemaLinks":"""['美术部'->部门]""",
      "analysis": """让我们一步一步地思考。在问题“内容库访问时长小于1小时,且来自美术部的用户是哪些“中,我们被问:
“用户是哪些”,所以我们需要column=[用户名]
”美术部的“,所以我们需要column=[部门], cell values = ['美术部'],所以有[部门:('美术部')]
”访问时长小于1小时“,所以我们需要column=[访问时长], cell values = [1],所以有[访问时长:(1)]""",
      "schemaLinks":"""["用户名", "部门":("'美术部'"), "访问时长":(1)]""",
      "sql":"""select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1"""
    },
    { "currentDate":"2023-08-31",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"内容库pv最高的用户有哪些",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库pv最高的用户有哪些“中,我们被问:
“用户有哪些”,所以我们需要column=[用户名]
”pv最高的“,所以我们需要column=[访问次数], cell values = [1],所以有[访问次数:(1)]""",
      "schemaLinks":"""["用户名", "访问次数":(1)]""",
      "sql":"""select 用户名 from 内容库产品 order by 访问次数 desc limit 1"""
    },
    { "currentDate":"2023-08-31",
      "tableName":"艺人库",
      "fieldsList":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
      "question":"近90天袁亚伟播放量平均值是多少",
      "priorSchemaLinks":"""['152789226'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“近90天袁亚伟播放量平均值是多少“中,我们被问:
“播放量平均值是多少”,所以我们需要column=[结算播放量]
”袁亚伟“,所以我们需要column=[歌手名], cell values = ['袁亚伟'],所以有[歌手名:('袁亚伟')]
”近90天“,所以我们需要column=[数据日期], cell values = [90],所以有[数据日期:(90)]""",
      "schemaLinks":"""["结算播放量", "歌手名":("'袁亚伟'"), "数据日期":(90)]""",
      "sql":"""select avg(结算播放量) from 艺人库 where 歌手名 = '袁亚伟' and datediff('day', 数据日期, '2023-08-31') <= 90 """
    },
    { "currentDate":"2023-08-31",
      "tableName":"艺人库",
      "fieldsList":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
      "question":"周倩倩近7天结算播放量总和是多少",
      "priorSchemaLinks":"""['199509'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“周倩倩近7天结算播放量总和是多少“中,我们被问:
“结算播放量总和是多少”,所以我们需要column=[结算播放量]
”周倩倩“,所以我们需要column=[歌手名], cell values = ['周倩倩'],所以有[歌手名:('周倩倩')]
”近7天“,所以我们需要column=[数据日期], cell values = [7],所以有[数据日期:(7)]""",
      "schemaLinks":"""["结算播放量", "歌手名":("'周倩倩'"), "数据日期":(7)]""",
      "sql":"""select sum(结算播放量) from 艺人库 where 歌手名 = '周倩倩' and datediff('day', 数据日期, '2023-08-31') <= 7 """
    },
    { "currentDate":"2023-09-14",
      "tableName":"内容库产品",
      "fieldsList":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""",
      "question":"内容库访问次数大于1k的部门是哪些",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库访问次数大于1k的部门是哪些“中,我们被问:
“部门是哪些”,所以我们需要column=[部门]
”访问次数大于1k的“,所以我们需要column=[访问次数], cell values = [1000],所以有[访问次数:(1000)]""",
      "schemaLinks":"""["部门", "访问次数":(1000)]""",
      "sql":"""select 部门 from 内容库产品 where 访问次数 > 1000"""
    },
    { "currentDate":"2023-09-18",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "MPPM歌手ID", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"陈亿训唱的所有的播放量大于20k的孤勇者有哪些",
      "priorSchemaLinks":"""['199509'->MPPM歌手ID, '1527123'->MPPM歌曲ID]""",
      "analysis": """让我们一步一步地思考。在问题“陈亿训唱的所有的播放量大于20k的孤勇者有哪些“中,我们被问:
“孤勇者有哪些”,所以我们需要column=[歌曲名], cell values = ['孤勇者'],所以有[歌曲名:('孤勇者')]
”播放量大于20k的“,所以我们需要column=[结算播放量], cell values = [20000],所以有[结算播放量:(20000)]
”陈亿训唱的“,所以我们需要column=[歌手名], cell values = ['陈亿训'],所以有[歌手名:('陈亿训')]""",
      "schemaLinks":"""["歌曲名":("'孤勇者'"), "结算播放量":(20000), "歌手名":("'陈亿训'")]""",
      "sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '陈亿训' and 歌曲名 = '孤勇者'"""
    },
    { "currentDate":"2023-09-18",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"周洁轮去年发布的歌曲有哪些",
      "priorSchemaLinks":"""['23109'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“周洁轮去年发布的歌曲有哪些“中,我们被问:
“歌曲有哪些”,所以我们需要column=[歌曲名]
”去年发布的“,所以我们需要column=[发布时间], cell values = [1],所以有[发布时间:(1)]
”周洁轮“,所以我们需要column=[歌手名], cell values = ['周洁轮'],所以有[歌手名:('周洁轮')]""",
      "schemaLinks":"""["歌曲名", "发布时间":(1), "歌手名":("'周洁轮'")]""",
      "sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发布时间, '2023-09-18') <= 1 and 歌手名 = '周洁轮'"""
    },
    { "currentDate":"2023-09-11",
      "tableName":"艺人库",
      "fieldsList":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "签约日期", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""",
      "question":"我想要近半年签约的播放量前十的歌手有哪些",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“我想要近半年签约的播放量前十的歌手“中,我们被问:
“歌手有哪些”,所以我们需要column=[歌手名]
”播放量前十的“,所以我们需要column=[结算播放量], cell values = [10],所以有[结算播放量:(10)]
”近半年签约的“,所以我们需要column=[签约日期], cell values = [0.5],所以有[签约日期:(0.5)]""",
      "schemaLinks":"""["歌手名", "结算播放量":(10), "签约日期":(0.5)]""",
      "sql":"""select 歌手名 from 艺人库 where datediff('year', 签约日期, '2023-09-11') <= 0.5 order by 结算播放量 desc limit 10"""
    },
    { "currentDate":"2023-08-12",
      "tableName":"歌曲库",
      "fieldsList": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
      "question":"最近一年发行的歌曲中,有哪些在近7天播放超过一千万的",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“最近一年发行的歌曲中,有哪些在近7天播放超过一千万的“中,我们被问:
“发行的歌曲中,有哪些”,所以我们需要column=[歌曲名]
”最近一年发行的“,所以我们需要column=[发行日期], cell values = [1],所以有[发行日期:(1)]
”在近7天播放超过一千万的“,所以我们需要column=[数据日期, 结算播放量], cell values = [7, 10000000],所以有[数据日期:(7), 结算播放量:(10000000)]""",
      "schemaLinks":"""["歌曲名", "发行日期":(1), "数据日期":(7), "结算播放量":(10000000)]""",
      "sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 1 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
    },
    { "currentDate":"2023-08-12",
      "tableName":"歌曲库",
      "fieldsList": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
      "question":"今年以来发行的歌曲中,有哪些在近7天播放超过一千万的",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“今年以来发行的歌曲中,有哪些在近7天播放超过一千万的“中,我们被问:
“发行的歌曲中,有哪些”,所以我们需要column=[歌曲名]
”今年以来发行的“,所以我们需要column=[发行日期], cell values = [0],所以有[发行日期:(0)]
”在近7天播放超过一千万的“,所以我们需要column=[数据日期, 结算播放量], cell values = [7, 10000000],所以有[数据日期:(7), 结算播放量:(10000000)]""",
      "schemaLinks":"""["歌曲名", "发行日期":(0), "数据日期":(7), "结算播放量":(10000000)]""",
      "sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 0 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
    },
    { "currentDate":"2023-08-12",
      "tableName":"歌曲库",
      "fieldsList": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""",
      "question":"2023年以来发行的歌曲中,有哪些在近7天播放超过一千万的",
      "priorSchemaLinks":"""['514129144'->MPPM歌曲ID]""",
      "analysis": """让我们一步一步地思考。在问题“2023年以来发行的歌曲中,有哪些在近7天播放超过一千万的“中,我们被问:
“发行的歌曲中,有哪些”,所以我们需要column=[歌曲名]
”2023年以来发行的“,所以我们需要column=[发行日期], cell values = ['2023-01-01'],所以有[发行日期:('2023-01-01')]
”在近7天播放超过一千万的“,所以我们需要column=[数据日期, 结算播放量], cell values = [7, 10000000],所以有[数据日期:(7), 结算播放量:(10000000)]""",
      "schemaLinks":"""["歌曲名", "发行日期":("'2023-01-01'"), "数据日期":(7), "结算播放量":(10000000)]""",
      "sql":"""select 歌曲名 from 歌曲库 where 发行日期 >= '2023-01-01' and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000"""
    },
    { "currentDate":"2023-08-01",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"周洁轮2023年6月之后发布的歌曲有哪些",
      "priorSchemaLinks":"""['23109'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“周洁轮2023年6月之后发布的歌曲有哪些“中,我们被问:
“歌曲有哪些”,所以我们需要column=[歌曲名]
”2023年6月之后发布的“,所以我们需要column=[发布时间], cell values = ['2023-06-01'],所以有[发布时间:('2023-06-01')]
”周洁轮“,所以我们需要column=[歌手名], cell values = ['周洁轮'],所以有[歌手名:('周洁轮')]""",
      "schemaLinks":"""["歌曲名", "发布时间":("'2023-06-01'"), "歌手名":("'周洁轮'")]""",
      "sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '2023-06-01' and 歌手名 = '周洁轮'"""
    },
    { "currentDate":"2023-08-01",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?",
      "priorSchemaLinks":"""['2312311'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?“中,我们被问:
“歌曲中,有哪些”,所以我们需要column=[歌曲名]
“播放量大于500W的”,所以我们需要column=[结算播放量], cell values = [5000000],所以有[结算播放量:(5000000)]
”邓梓琦在2023年1月5日之后发布的“,所以我们需要column=[发布时间], cell values = ['2023-01-05'],所以有[发布时间:('2023-01-05')]
”邓梓琦“,所以我们需要column=[歌手名], cell values = ['邓梓琦'],所以有[歌手名:('邓梓琦')]""",
      "schemaLinks":"""["歌曲名", "结算播放量":(5000000), "发布时间":("'2023-01-05'"), "歌手名":("'邓梓琦'")]""",
      "sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '2023-01-05' and 歌手名 = '邓梓琦' and 结算播放量 > 5000000"""
    },
    { "currentDate":"2023-09-17",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"2023年6月以后,张亮英播放量大于200万的歌曲有哪些?",
      "priorSchemaLinks":"""['45453'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“2023年6月以后,张亮英播放量大于200万的歌曲有哪些?“中,我们被问:
“播放量大于200万的”,所以我们需要column=[结算播放量], cell values = [2000000],所以有[结算播放量:(2000000)]
”2023年6月以后,张亮英“,所以我们需要column=[数据日期, 歌手名], cell values = ['2023-06-01', '张亮英'],所以有[数据日期:('2023-06-01'), 歌手名:('张亮英')],
”歌曲有哪些“,所以我们需要column=[歌曲名]""",
      "schemaLinks":"""["结算播放量":(2000000), "数据日期":("'2023-06-01'"), "歌手名":("'张亮英'"), "歌曲名"]""",
      "sql":"""select 歌曲名 from 歌曲库 where 数据日期 >= '2023-06-01' and 歌手名 = '张亮英' and 结算播放量 > 2000000"""
    },
    { "currentDate":"2023-08-16",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些",
      "priorSchemaLinks":"""['23109'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些“中,我们被问:
“播放量大于20万的”,所以我们需要column=[结算播放量], cell values = [200000],所以有[结算播放量:(200000)]
”2021年6月以后发布的“,所以我们需要column=[发布时间], cell values = ['2021-06-01'],所以有[发布时间:('2021-06-01')]
”李雨纯“,所以我们需要column=[歌手名], cell values = ['李雨纯'],所以有[歌手名:('李雨纯')]""",
      "schemaLinks":"""["结算播放量":(200000), "发布时间":("'2021-06-01'"), "歌手名":("'李雨纯'")]""",
      "sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '2021-06-01' and 歌手名 = '李雨纯' and 结算播放量 > 200000"""
    },
    { "currentDate":"2023-08-16",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""",
      "question":"刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些",
      "priorSchemaLinks":"""['4234234'->MPPM歌手ID]""",
      "analysis": """让我们一步一步地思考。在问题“刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些“中,我们被问:
“播放量大于20万的”,所以我们需要column=[结算播放量], cell values = [200000],所以有[结算播放量:(200000)]
”1992年4月2日到2020年5月2日之间发布的“, 所以我们需要column=[发布时间], cell values = ['1992-04-02', '2020-05-02'],所以有[发布时间:('1992-04-02', '2020-05-02')]
”刘锝桦“,所以我们需要column=[歌手名], cell values = ['刘锝桦'],所以有[歌手名:('刘锝桦')]""",
      "schemaLinks":"""["结算播放量":(200000), "发布时间":("'1992-04-02'", "'2020-05-02'"), "歌手名":("'刘锝桦'")]""",
      "sql":"""select 歌曲名 from 歌曲库 where 发布时间 >= '1992-04-02' and 发布时间 <= '2020-05-02' and 歌手名 = '刘锝桦' and 结算播放量 > 200000"""
    },
    { "currentDate":"2023-09-04",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"内容库近30天访问次数的平均数",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库近30天访问次数的平均数“中,我们被问:
“访问次数的平均数”,所以我们需要column=[访问次数]
”内容库近30天“,所以我们需要column=[数据日期], cell values = [30],所以有[数据日期:(30)]""",
      "schemaLinks":"""["访问次数", "数据日期":(30)]""",
      "sql":"""select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """
    },
    { "currentDate":"2023-09-04",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"内容库近半年哪个月的访问次数汇总最高",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库近半年哪个月的访问次数汇总最高“中,我们被问:
“访问次数汇总最高”,所以我们需要column=[访问次数], cell values = [1],所以有[访问次数:(1)]
”内容库近半年“,所以我们需要column=[数据日期], cell values = [0.5],所以有[数据日期:(0.5)]""",
      "schemaLinks":"""["访问次数":(1), "数据日期":(0.5)]""",
      "sql":"""select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1"""
    },
    { "currentDate":"2023-09-04",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"内容库近半年每个月的平均访问次数",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库近半年每个月的平均访问次数“中,我们被问:
“每个月的平均访问次数”,所以我们需要column=[访问次数]
”内容库近半年“,所以我们需要column=[数据日期], cell values = [0.5],所以有[数据日期:(0.5)]""",
      "schemaLinks":"""["访问次数", "数据日期":(0.5)]""",
      "sql":"""select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期)"""
    },
    { "currentDate":"2023-09-10",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"内容库 按部门统计访问次数 top10 的部门",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“内容库 按部门统计访问次数 top10 的部门“中,我们被问:
“访问次数 top10 的部门”,所以我们需要column=[访问次数], cell values = [10],所以有[访问次数:(10)]
”内容库 按部门统计“,所以我们需要column=[部门]""",
      "schemaLinks":"""["访问次数":(10), "部门"]""",
      "sql":"""select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10"""
    },
    { "currentDate":"2023-09-10",
      "tableName":"内容库产品",
      "fieldsList":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""",
      "question":"超音速 近7个月,月度总访问量超过 2万的月份",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“超音速 近7个月,月度总访问量超过 2万的月份“中,我们被问:
“月度总访问量超过 2万的月份”,所以我们需要column=[访问次数], cell values = [20000],所以有[访问次数:(20000)]
”超音速 近7个月“,所以我们需要column=[数据日期], cell values = [7],所以有[数据日期:(7)]""",
      "schemaLinks":"""["访问次数":(20000), "数据日期":(7)]""",
      "sql":"""select MONTH(数据日期) from 内容库产品 where datediff('day', 数据日期, '2023-09-10') <= 7 group by MONTH(数据日期) having sum(访问次数) > 20000"""
    },
    { "currentDate":"2023-09-10",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲语言", "歌曲来源", "运营播放量", "播放量", "歌曲名", "结算播放量", "专辑名", "发布日期", "歌曲版本", "歌曲类型", "数据日期"]""",
      "question":"2022年7月到2023年7月之间发布到歌曲,按播放量取top 100,再按月粒度来统计近1年的运营播放量",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“2022年7月到2023年7月之间发布到歌曲,按播放量取top 100,再按月粒度来统计近1年的运营播放量“中,我们被问:
“按月粒度来统计近1年的运营播放量”,所以我们需要column=[运营播放量, 数据日期], cell values = [1],所以有[运营播放量, 数据日期:(1)]
”按播放量取top 100“,所以我们需要column=[播放量], cell values = [100],所以有[播放量:(100)]
“2022年7月到2023年7月之间发布到歌曲”,所以我们需要column=[发布日期], cell values = ['2022-07-01', '2023-07-01'],所以有[发布日期:('2022-07-01', '2023-07-01')]""",
      "schemaLinks":"""["运营播放量", "数据日期":(1), "播放量":(100), "发布日期":("'2022-07-01'", "'2023-07-01'")]""",
      "sql":"""select MONTH(数据日期), sum(运营播放量) from (select 数据日期, 运营播放量 from 歌曲库 where 发布日期 >= '2022-07-01' and 发布日期 <= '2023-07-01' order by 播放量 desc limit 100) t where datediff('year', 数据日期, '2023-09-10') <= 1 group by MONTH(数据日期)"""
    },
    { "currentDate":"2023-09-10",
      "tableName":"歌曲库",
      "fieldsList":"""["歌曲语言", "歌曲来源", "运营播放量", "播放量", "歌曲名", "结算播放量", "专辑名", "发布日期", "歌曲版本", "歌曲类型", "数据日期"]""",
      "question":"2022年7月到2023年7月之间发布到歌曲,按播放量取top100,再按月粒度来统计近1年的运营播放量之和,筛选出其中运营播放量之和大于2k的月份",
      "priorSchemaLinks":"""[]""",
      "analysis": """让我们一步一步地思考。在问题“2022年7月到2023年7月之间发布到歌曲,按播放量取top100,再按月粒度来统计近1年的运营播放量之和,筛选出其中运营播放量之和大于2k的月份“中,我们被问:
“筛选出其中运营播放量之和大于2k的月份”,所以我们需要column=[运营播放量], cell values = [2000],所以有[运营播放量:(2000)]
”按月粒度来统计近1年的运营播放量之和“,所以我们需要column=[数据日期], cell values = [1],所以有[数据日期:(1)]
”按播放量取top100“,所以我们需要column=[播放量], cell values = [100],所以有[播放量:(100)]
”2022年7月到2023年7月之间发布到歌曲“,所以我们需要column=[发布日期], cell values = ['2022-07-01', '2023-07-01'],所以有[发布日期:('2022-07-01', '2023-07-01')]""",
      "schemaLinks":"""["运营播放量":(2000), "数据日期":(1), "播放量":(100), "发布日期":("'2022-07-01'", "'2023-07-01'")]""",
      "sql":"""select MONTH(数据日期), sum(运营播放量) from (select 数据日期, 运营播放量 from 歌曲库 where 发布日期 >= '2022-07-01' and 发布日期 <= '2023-07-01' order by 播放量 desc limit 100) t where datediff('year', 数据日期, '2023-09-10') <= 1 group by MONTH(数据日期) having sum(运营播放量) > 2000"""
    },
    { "currentDate":"2023-11-01",
      "tableName":"营销月模型",
      "fieldsList":"""["国家中文名", "机型类别", "销量", "数据日期"]""",
      "question":"今年智能机在哪个国家的销量之和最高",
      "priorSchemaLinks":"""['智能机'->机型类别]""",
      "analysis": """让我们一步一步地思考。在问题“今年智能机在哪个国家的销量之和最高“中,我们被问:
“销量最高”,所以我们需要column=[销量], cell values = [1],所以有[销量:(1)]
”今年“,所以我们需要column=[数据日期], cell values = ['2023-01-01', '2023-11-01'],所以有[数据日期:('2023-01-01', '2023-11-01')]
”智能机“,所以我们需要column=[机型类别], cell values = ['智能机'],所以有[机型类别:('智能机')]""",
      "schemaLinks":"""["销量":(1), "数据日期":("'2023-01-01'", "'2023-11-01'"), "机型类别":("'智能机'")]""",
      "sql":"""select 国家中文名, sum(销量) from 营销月模型 where 机型类别 = '智能机' and 数据日期 >= '2023-01-01' and 数据日期 <= '2023-11-01' group by 国家中文名 order by sum(销量) desc limit 1"""
    }
]
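
Each examplar above carries exactly the keys that the few-shot templates in services/sql/sql_agent.py below interpolate. A hedged sketch of how one entry renders through the SQL-generation template defined there:

    # Illustrative sketch only; the template string is the one used by generate_sql_prompt below.
    from few_shot_example.sql_examplar import examplars
    sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"
    keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
    print(sql_example_template.format(**{k: examplars[0][k] for k in keys}))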
chat/python/instances/chromadb_instance.py · Normal file · 21 lines added
@@ -0,0 +1,21 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union

import chromadb
from chromadb.api import Collection
from chromadb.config import Settings

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config.config_parse import CHROMA_DB_PERSIST_PATH


client = chromadb.Client(
    Settings(
        chroma_db_impl="duckdb+parquet",
        persist_directory=CHROMA_DB_PERSIST_PATH,  # Optional, defaults to .chromadb/ in the current directory
    )
)
chat/python/instances/llm_instance.py · Normal file · 21 lines added
@@ -0,0 +1,21 @@
# -*- coding:utf-8 -*-
from langchain import llms

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config.config_parse import LLM_PROVIDER_NAME, llm_config_dict


def get_llm_provider(llm_provider_name: str, llm_config_dict: dict):
    # Look up the LangChain LLM class registered under this provider name
    # and instantiate it with the options parsed from run_config.ini.
    if llm_provider_name in llms.type_to_cls_dict:
        llm_provider = llms.type_to_cls_dict[llm_provider_name]
        llm = llm_provider(**llm_config_dict)
        return llm
    else:
        raise Exception("llm_provider_name is not supported: {}".format(llm_provider_name))


llm = get_llm_provider(LLM_PROVIDER_NAME, llm_config_dict)
chat/python/instances/logging_instance.py · Normal file · 6 lines added
@@ -0,0 +1,6 @@
from loguru import logger
import sys

# Remove loguru's default handler; otherwise it would keep emitting alongside the new sink added below.
logger.remove()
logger.add(sys.stdout, format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}", level="INFO")
chat/python/instances/text2vec.py · Normal file · 23 lines added
@@ -0,0 +1,23 @@
# -*- coding:utf-8 -*-
from typing import List

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from langchain.embeddings import HuggingFaceEmbeddings

from config.config_parse import HF_TEXT2VEC_MODEL_NAME

hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME)


class Text2VecEmbeddingFunction(EmbeddingFunction):
    def __call__(self, texts: Documents) -> Embeddings:
        embeddings = hg_embedding.embed_documents(texts)
        return embeddings


def get_embeddings(documents: List[str]) -> List[List[float]]:
    embeddings = hg_embedding.embed_documents(documents)
    return embeddings
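
A minimal usage sketch (not in the commit): the ChromaDB EmbeddingFunction wrapper and the standalone helper both delegate to the same hg_embedding model, so they produce identical vectors for identical input.

    # Illustrative sketch only.
    from instances.text2vec import Text2VecEmbeddingFunction, get_embeddings
    emb_func = Text2VecEmbeddingFunction()
    vectors = emb_func(["近7天播放量"])  # invoked by chromadb during add/query
    assert vectors == get_embeddings(["近7天播放量"])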
chat/python/requirements.txt · Normal file · 8 lines added
@@ -0,0 +1,8 @@
langchain==0.0.207
openai==0.27.4
fastapi==0.95.1
chromadb==0.3.26
tiktoken==0.3.3
uvicorn[standard]==0.21.1
pandas==1.5.3
loguru==0.7.2
chat/python/services/plugin_call/prompt_construct.py · Normal file · 99 lines added
@@ -0,0 +1,99 @@
# -*- coding:utf-8 -*-
import json
import os
import re
import sys
from typing import Any, List, Mapping, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger


def construct_plugin_prompt(tool_config):
    tool_name = tool_config["name"]
    tool_description = tool_config["description"]
    tool_examples = tool_config["examples"]

    prompt = "【工具名称】\n" + tool_name + "\n"
    prompt += "【工具描述】\n" + tool_description + "\n"

    prompt += "【工具适用问题示例】\n"
    for example in tool_examples:
        prompt += example + "\n"
    return prompt


def construct_plugin_pool_prompt(tool_config_list):
    tool_explain_list = []
    for tool_config in tool_config_list:
        tool_explain = construct_plugin_prompt(tool_config)
        tool_explain_list.append(tool_explain)

    tool_explain_list_str = "\n\n".join(tool_explain_list)

    return tool_explain_list_str


def construct_task_prompt(query_text, tool_explain_list_str):
    instruction = """问题为:{query_text}\n请根据问题和工具的描述,选择对应的工具,完成任务。请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘""".format(
        query_text=query_text
    )

    prompt = "工具选择如下:\n\n{tool_explain_list_str}\n\n【任务说明】\n{instruction}".format(
        instruction=instruction, tool_explain_list_str=tool_explain_list_str
    )

    return prompt


def plugin_selection_output_parse(llm_output: str) -> Union[Mapping[str, str], None]:
    try:
        # Grab the first {...} block from the LLM output and parse it as JSON.
        pattern = r"\{[^{}]+\}"
        find_result = re.findall(pattern, llm_output)
        result = find_result[0].strip()

        logger.info("result: {}", result)

        result_dict = json.loads(result)
        logger.info("result_dict: {}", result_dict)

        # Map the Chinese keys emitted by the prompt to the English keys expected downstream.
        key_mapping = {"分析过程": "analysis", "选择工具": "toolSelection"}

        converted_result_dict = {
            key_mapping[key]: value
            for key, value in result_dict.items()
            if key in key_mapping
        }

    except Exception as e:
        logger.exception(e)
        converted_result_dict = None

    return converted_result_dict


def plugins_config_format_convert(
    plugin_config_list: List[Mapping[str, Any]]
) -> List[Mapping[str, Any]]:
    plugin_config_list_new = []
    for plugin_config in plugin_config_list:
        plugin_config_new = dict()
        name = plugin_config["name"]
        description = plugin_config["description"]
        examples = plugin_config["examples"]
        parameters = plugin_config["parameters"]

        examples_str = "\n".join(examples)
        description_new = """{plugin_desc}\n\n例如能够处理如下问题:\n{examples_str}""".format(
            plugin_desc=description, examples_str=examples_str
        )

        plugin_config_new["name"] = name
        plugin_config_new["description"] = description_new
        plugin_config_new["parameters"] = parameters

        plugin_config_list_new.append(plugin_config_new)

    return plugin_config_list_new
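
A hedged end-to-end sketch of this module with a hypothetical tool config (the name/description/examples keys are the ones construct_plugin_prompt reads):

    # Illustrative sketch only; the tool config below is hypothetical.
    tools = [{"name": "内容库查询", "description": "查询内容库访问指标", "examples": ["内容库近7天访问次数"]}]
    pool_prompt = construct_plugin_pool_prompt(tools)
    task_prompt = construct_task_prompt("内容库昨天的访问人数", pool_prompt)
    # An LLM reply of '{"分析过程": "...", "选择工具": "内容库查询"}' parses to
    # {"analysis": "...", "toolSelection": "内容库查询"}:
    parsed = plugin_selection_output_parse('{"分析过程": "...", "选择工具": "内容库查询"}')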
chat/python/services/plugin_call/run.py · Normal file · 28 lines added
@@ -0,0 +1,28 @@
# -*- coding:utf-8 -*-

import os
import sys
from typing import Any, List, Mapping, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from plugin_call.prompt_construct import (
    construct_plugin_pool_prompt,
    construct_task_prompt,
    plugin_selection_output_parse,
)
from instances.llm_instance import llm


def plugin_selection_run(
    query_text: str, plugin_configs: List[Mapping[str, Any]]
) -> Union[Mapping[str, str], None]:

    tools_prompt = construct_plugin_pool_prompt(plugin_configs)

    task_prompt = construct_task_prompt(query_text, tools_prompt)
    llm_output = llm(task_prompt)
    parsed_output = plugin_selection_output_parse(llm_output)

    return parsed_output
chat/python/services/query_retrieval/retriever.py · Normal file · 98 lines added
@@ -0,0 +1,98 @@
# -*- coding:utf-8 -*-

import os
import sys
import uuid
from typing import Any, List, Mapping, Optional, Union

import chromadb
from chromadb import Client
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings
from chromadb.api.types import CollectionMetadata

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger
from utils.chromadb_utils import (get_chroma_collection_size, query_chroma_collection,
                                  parse_retrieval_chroma_collection_query, chroma_collection_query_retrieval_format,
                                  get_chroma_collection_by_ids,
                                  add_chroma_collection, update_chroma_collection, delete_chroma_collection_by_ids,
                                  empty_chroma_collection_2)

from instances.text2vec import Text2VecEmbeddingFunction


class ChromaCollectionRetriever(object):
    def __init__(self, collection: Collection):
        self.collection = collection

    def retrieval_query_run(self, query_texts_list: List[str] = None, query_embeddings: Embeddings = None,
                            filter_condition: Mapping[str, str] = None, n_results: int = 5):

        retrieval_res = query_chroma_collection(self.collection, query_texts_list, query_embeddings,
                                                filter_condition, n_results)

        parsed_retrieval_res = parse_retrieval_chroma_collection_query(retrieval_res)
        logger.debug('parsed_retrieval_res: {}', parsed_retrieval_res)
        parsed_retrieval_res_format = chroma_collection_query_retrieval_format(query_texts_list, query_embeddings, parsed_retrieval_res)
        logger.debug('parsed_retrieval_res_format: {}', parsed_retrieval_res_format)

        return parsed_retrieval_res_format

    def get_query_by_ids(self, query_ids: List[str]):
        queries = get_chroma_collection_by_ids(self.collection, query_ids)
        return queries

    def get_query_size(self):
        return get_chroma_collection_size(self.collection)

    def add_queries(self, query_text_list: List[str],
                    query_id_list: List[str],
                    metadatas: List[Mapping[str, str]] = None,
                    embeddings: Embeddings = None):
        add_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
        return True

    def update_queries(self, query_text_list: List[str],
                       query_id_list: List[str],
                       metadatas: List[Mapping[str, str]] = None,
                       embeddings: Embeddings = None):
        update_chroma_collection(self.collection, query_text_list, query_id_list, metadatas, embeddings)
        return True

    def delete_queries_by_ids(self, query_ids: List[str]):
        delete_chroma_collection_by_ids(self.collection, query_ids)
        return True

    def empty_query_collection(self):
        self.collection = empty_chroma_collection_2(self.collection)

        return True


class CollectionManager(object):
    def __init__(self, chroma_client: Client, embedding_func: Text2VecEmbeddingFunction, collection_meta: Optional[CollectionMetadata] = None):
        self.chroma_client = chroma_client
        self.embedding_func = embedding_func
        self.collection_meta = collection_meta

    def list_collections(self):
        collection_list = self.chroma_client.list_collections()
        return collection_list

    def get_collection(self, collection_name: str):
        collection = self.chroma_client.get_collection(name=collection_name, embedding_function=self.embedding_func)
        return collection

    def create_collection(self, collection_name: str):
        collection = self.chroma_client.create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
        return collection

    def get_or_create_collection(self, collection_name: str):
        collection = self.chroma_client.get_or_create_collection(name=collection_name, embedding_function=self.embedding_func, metadata=self.collection_meta)
        return collection

    def delete_collection(self, collection_name: str):
        self.chroma_client.delete_collection(collection_name)
        return True
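
A brief round-trip sketch (not in the commit) of ChromaCollectionRetriever; the collection would be obtained via CollectionManager as in run.py below:

    # Illustrative sketch only; some_collection is a placeholder.
    retriever = ChromaCollectionRetriever(some_collection)
    retriever.add_queries(query_text_list=["近7天播放量"], query_id_list=["1"], metadatas=[{"sql": "..."}])
    hits = retriever.retrieval_query_run(query_texts_list=["近一周播放量"], n_results=1)
    # hits[0]['retrieval'] holds the nearest stored queries plus their metadata.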
chat/python/services/query_retrieval/run.py · Normal file · 37 lines added
@@ -0,0 +1,37 @@
# -*- coding:utf-8 -*-

import os
import sys
import uuid
from typing import Any, List, Mapping, Optional, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger

import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings

from instances.text2vec import Text2VecEmbeddingFunction
from instances.chromadb_instance import client

from config.config_parse import SOLVED_QUERY_COLLECTION_NAME, PRESET_QUERY_COLLECTION_NAME
from retriever import ChromaCollectionRetriever, CollectionManager


emb_func = Text2VecEmbeddingFunction()

collection_manager = CollectionManager(chroma_client=client, embedding_func=emb_func,
                                       collection_meta={"hnsw:space": "cosine"})

solved_query_collection = collection_manager.get_or_create_collection(collection_name=SOLVED_QUERY_COLLECTION_NAME)
preset_query_collection = collection_manager.get_or_create_collection(collection_name=PRESET_QUERY_COLLECTION_NAME)


solved_query_retriever = ChromaCollectionRetriever(solved_query_collection)
preset_query_retriever = ChromaCollectionRetriever(preset_query_collection)

logger.info("init_solved_query_collection_size: {}".format(solved_query_retriever.get_query_size()))
logger.info("init_preset_query_collection_size: {}".format(preset_query_retriever.get_query_size()))
chat/python/services/sql/constructor.py · Normal file · 75 lines added
@@ -0,0 +1,75 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import List, Mapping
from chromadb.api import Collection

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger
from services.query_retrieval.retriever import ChromaCollectionRetriever


class FewShotPromptTemplate2(object):
    def __init__(self, collection: Collection, retrieval_key: str, few_shot_seperator: str = "\n\n") -> None:
        self.collection = collection
        self.few_shot_retriever = ChromaCollectionRetriever(self.collection)

        self.retrieval_key = retrieval_key

        self.few_shot_seperator = few_shot_seperator

    def add_few_shot_example(self, example_ids: List[str], example_units: List[Mapping[str, str]]) -> None:
        query_text_list = []

        for idx, example_unit in enumerate(example_units):
            query_text_list.append(example_unit[self.retrieval_key])

        self.few_shot_retriever.add_queries(query_text_list=query_text_list, query_id_list=example_ids, metadatas=example_units)

    def update_few_shot_example(self, example_ids: List[str], example_units: List[Mapping[str, str]]) -> None:
        query_text_list = []

        for idx, example_unit in enumerate(example_units):
            query_text_list.append(example_unit[self.retrieval_key])

        self.few_shot_retriever.update_queries(query_text_list=query_text_list, query_id_list=example_ids, metadatas=example_units)

    def delete_few_shot_example(self, example_ids: List[str]) -> None:
        self.few_shot_retriever.delete_queries_by_ids(query_ids=example_ids)

    def count_few_shot_example(self) -> int:
        return self.few_shot_retriever.get_query_size()

    def reload_few_shot_example(self, example_ids: List[str], example_units: List[Mapping[str, str]]) -> None:
        logger.info(f"original {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")

        self.few_shot_retriever.empty_query_collection()
        logger.info(f"emptied {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")

        self.add_few_shot_example(example_ids=example_ids, example_units=example_units)
        logger.info(f"reloaded {self.collection.name} size: {self.few_shot_retriever.get_query_size()}")

    def _sub_dict(self, d: Mapping[str, str], keys: List[str]) -> Mapping[str, str]:
        return {k: d[k] for k in keys if k in d}

    def retrieve_few_shot_example(self, query_text: str, retrieval_num: int, filter_condition: Mapping[str, str] = None) -> List[Mapping[str, str]]:
        query_text_list = [query_text]
        retrieval_res_list = self.few_shot_retriever.retrieval_query_run(query_texts_list=query_text_list,
                                                                         filter_condition=filter_condition, n_results=retrieval_num)
        retrieval_res_unit_list = retrieval_res_list[0]['retrieval']

        return retrieval_res_unit_list

    def make_few_shot_example_prompt(self, few_shot_template: str, example_keys: List[str],
                                     few_shot_example_meta_list: List[Mapping[str, str]]) -> str:
        few_shot_example_str_unit_list = []

        retrieval_metas_list = [self._sub_dict(few_shot_example_meta['metadata'], example_keys) for few_shot_example_meta in few_shot_example_meta_list]

        for meta in retrieval_metas_list:
            few_shot_example_str_unit_list.append(few_shot_template.format(**meta))

        few_shot_example_str = self.few_shot_seperator.join(few_shot_example_str_unit_list)

        return few_shot_example_str
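
A hedged sketch of the intended call sequence, using the question retrieval key that services/sql/run.py below configures:

    # Illustrative sketch only; some_collection is a placeholder.
    prompter = FewShotPromptTemplate2(collection=some_collection, retrieval_key="question")
    prompter.reload_few_shot_example(example_ids=["0", "1"], example_units=examplars[:2])
    metas = prompter.retrieve_few_shot_example("近7天播放量", retrieval_num=2)
    prompt = prompter.make_few_shot_example_prompt("问题:{question}\nSQL:{sql}", ["question", "sql"], metas)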
chat/python/services/sql/examples_reload_run.py · Normal file · 61 lines added
@@ -0,0 +1,61 @@
# -*- coding:utf-8 -*-
import json
import os
import sys
from typing import List, Mapping

import requests

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger

from config.config_parse import (TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
                                 LLMPARSER_HOST, LLMPARSER_PORT, TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)
from few_shot_example.sql_examplar import examplars as sql_examplars


def text2sql_agent_setting_update(llm_host: str, llm_port: str,
                                  sql_examplars: List[Mapping[str, str]], example_nums: int):

    url = f"http://{llm_host}:{llm_port}/text2sql_agent_setting_update/"
    payload = {"sqlExamplars": sql_examplars, "exampleNums": example_nums}
    headers = {'content-type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    logger.info(response.text)


def text2dsl_agent_cs_setting_update(llm_host: str, llm_port: str,
                                     sql_examplars: List[Mapping[str, str]], example_nums: int, fewshot_nums: int, self_consistency_nums: int):

    url = f"http://{llm_host}:{llm_port}/texg2sqt_cs_agent_setting_update/"
    payload = {"sqlExamplars": sql_examplars,
               "exampleNums": example_nums, "fewshotNums": fewshot_nums, "selfConsistencyNums": self_consistency_nums}
    headers = {'content-type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    logger.info(response.text)


def text2dsl_agent_wrapper_setting_update(llm_host: str, llm_port: str,
                                          is_shortcut: bool, is_self_consistency: bool,
                                          sql_examplars: List[Mapping[str, str]], example_nums: int, fewshot_nums: int, self_consistency_nums: int):

    sql_ids = [str(i) for i in range(0, len(sql_examplars))]

    url = f"http://{llm_host}:{llm_port}/query2sql_setting_update/"
    payload = {"isShortcut": is_shortcut, "isSelfConsistency": is_self_consistency,
               "sqlExamplars": sql_examplars, "sqlIds": sql_ids,
               "exampleNums": example_nums, "fewshotNums": fewshot_nums, "selfConsistencyNums": self_consistency_nums}
    headers = {'content-type': 'application/json'}
    response = requests.post(url, data=json.dumps(payload), headers=headers)
    logger.info(response.text)


if __name__ == "__main__":
    text2dsl_agent_wrapper_setting_update(LLMPARSER_HOST, LLMPARSER_PORT,
                                          TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY,
                                          sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)
chat/python/services/sql/output_parser.py · Normal file · 57 lines added
@@ -0,0 +1,57 @@
# -*- coding:utf-8 -*-
import re

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger


def schema_link_parse(schema_link_output):
    # Pull everything after the "Schema_links:" marker out of the LLM response.
    try:
        schema_link_output = schema_link_output.strip()
        pattern = r"Schema_links:(.*)"
        schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip()
    except Exception as e:
        logger.exception(e)
        schema_link_output = None

    return schema_link_output


def combo_schema_link_parse(schema_linking_sql_combo_output: str):
    # Extract the bracketed schema-links list from a combined schema-linking + SQL response.
    try:
        schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
        pattern = r"Schema_links:(\[.*?\])"
        schema_links_match = re.search(pattern, schema_linking_sql_combo_output)

        if schema_links_match:
            schema_links = schema_links_match.group(1)
        else:
            schema_links = None
    except Exception as e:
        logger.exception(e)
        schema_links = None

    return schema_links


def combo_sql_parse(schema_linking_sql_combo_output: str):
    # Extract the generated SQL after the "SQL:" marker.
    try:
        schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip()
        pattern = r"SQL:(.*)"
        sql_match = re.search(pattern, schema_linking_sql_combo_output)

        if sql_match:
            sql = sql_match.group(1)
        else:
            sql = None
    except Exception as e:
        logger.exception(e)
        sql = None

    return sql
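
A quick sketch of the parsers on a typical combined response (the string is hypothetical; the Schema_links:/SQL: markers match the prompt formats in sql_agent.py below):

    # Illustrative sketch only.
    llm_output = '分析: ...\nSchema_links:["歌曲名", "结算播放量":(10000)]\nSQL:select 歌曲名 from 歌曲库 where 结算播放量 > 10000'
    combo_schema_link_parse(llm_output)  # -> '["歌曲名", "结算播放量":(10000)]'
    combo_sql_parse(llm_output)          # -> 'select 歌曲名 from 歌曲库 where 结算播放量 > 10000'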
chat/python/services/sql/run.py · Normal file · 54 lines added
@@ -0,0 +1,54 @@
# -*- coding:utf-8 -*-

import asyncio

import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from sql.constructor import FewShotPromptTemplate2
from sql.sql_agent import Text2DSLAgent, Text2DSLAgentConsistency, Text2DSLAgentWrapper

from instances.llm_instance import llm
from instances.text2vec import Text2VecEmbeddingFunction
from instances.chromadb_instance import client
from instances.logging_instance import logger

from few_shot_example.sql_examplar import examplars as sql_examplars
from config.config_parse import (TEXT2DSLAGENT_COLLECTION_NAME, TEXT2DSLAGENTCS_COLLECTION_NAME,
                                 TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM,
                                 TEXT2DSL_IS_SHORTCUT, TEXT2DSL_IS_SELF_CONSISTENCY)


emb_func = Text2VecEmbeddingFunction()
text2dsl_agent_collection = client.get_or_create_collection(name=TEXT2DSLAGENT_COLLECTION_NAME,
                                                            embedding_function=emb_func,
                                                            metadata={"hnsw:space": "cosine"})
text2dsl_agentcs_collection = client.get_or_create_collection(name=TEXT2DSLAGENTCS_COLLECTION_NAME,
                                                              embedding_function=emb_func,
                                                              metadata={"hnsw:space": "cosine"})

text2dsl_agent_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agent_collection,
                                                         retrieval_key="question",
                                                         few_shot_seperator='\n\n')

text2dsl_agentcs_example_prompter = FewShotPromptTemplate2(collection=text2dsl_agentcs_collection,
                                                           retrieval_key="question",
                                                           few_shot_seperator='\n\n')

text2sql_agent = Text2DSLAgent(num_fewshots=TEXT2DSL_EXAMPLE_NUM,
                               sql_example_prompter=text2dsl_agent_example_prompter, llm=llm)

text2sql_cs_agent = Text2DSLAgentConsistency(num_fewshots=TEXT2DSL_FEWSHOTS_NUM, num_examples=TEXT2DSL_EXAMPLE_NUM, num_self_consistency=TEXT2DSL_SELF_CONSISTENCY_NUM,
                                             sql_example_prompter=text2dsl_agentcs_example_prompter, llm=llm)

sql_ids = [str(i) for i in range(0, len(sql_examplars))]
text2sql_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM)
text2sql_cs_agent.reload_setting(sql_ids, sql_examplars, TEXT2DSL_EXAMPLE_NUM, TEXT2DSL_FEWSHOTS_NUM, TEXT2DSL_SELF_CONSISTENCY_NUM)


text2sql_agent_router = Text2DSLAgentWrapper(sql_agent=text2sql_agent, sql_agent_cs=text2sql_cs_agent,
                                             is_shortcut=TEXT2DSL_IS_SHORTCUT, is_self_consistency=TEXT2DSL_IS_SELF_CONSISTENCY)
chat/python/services/sql/sql_agent.py · Normal file · 405 lines added
@@ -0,0 +1,405 @@
|
||||
import os
import sys
from typing import List, Union, Mapping, Any
from collections import Counter
import random
import asyncio

from langchain.llms.base import BaseLLM

sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger

from sql.constructor import FewShotPromptTemplate2
from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse


class Text2DSLAgent(object):
    def __init__(self, num_fewshots: int,
                 sql_example_prompter: FewShotPromptTemplate2,
                 llm: BaseLLM):
        self.num_fewshots = num_fewshots
        self.sql_example_prompter = sql_example_prompter
        self.llm = llm

    def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]], num_fewshots: int):
        self.num_fewshots = num_fewshots
        self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)

    def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)

    def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)

    def delete_examples(self, sql_example_ids: List[str]):
        self.sql_example_prompter.delete_few_shot_example(sql_example_ids)

    def count_examples(self):
        return self.sql_example_prompter.count_few_shot_example()

    def get_fewshot_examples(self, query_text: str, filter_condition: Mapping[str, str]) -> List[Mapping[str, str]]:
        few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_fewshots, filter_condition)

        return few_shot_example_meta_list

    def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
                                       prior_schema_links: Mapping[str, str], fewshot_example_list: List[Mapping[str, str]]) -> str:
        prior_schema_links_str = '[' + ','.join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

        schema_linking_example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"]
        schema_linking_example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"
        schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
                                                                                               example_keys=schema_linking_example_keys,
                                                                                               few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_schema_links_str, question=user_query)

        schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
        return schema_linking_prompt

    def generate_sql_prompt(self, user_query: str, domain_name: str,
                            schema_link_str: str, data_date: str,
                            fewshot_example_list: List[Mapping[str, str]]) -> str:
        instruction = "# 根据schema_links为每个问题生成SQL查询语句"
        sql_example_keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
        sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"

        sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
                                                                                            example_keys=sql_example_keys,
                                                                                            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
        new_case_prompt = new_case_template.format(question=user_query, currentDate=data_date, tableName=domain_name, schemaLinks=schema_link_str)

        sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt

        return sql_example_prompt

    def generate_schema_linking_sql_prompt(self, user_query: str,
                                           domain_name: str,
                                           data_date: str,
                                           fields_list: List[str],
                                           prior_schema_links: Mapping[str, str],
                                           fewshot_example_list: List[Mapping[str, str]]):
        prior_schema_links_str = '[' + ','.join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句"

        example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "currentDate", "question", "analysis", "schemaLinks", "sql"]
        example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}\nSQL:{sql}"
        fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=example_template,
                                                                                example_keys=example_keys,
                                                                                few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\nCurrent_date:{currentDate}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_schema_links_str, currentDate=data_date, question=user_query)

        prompt = instruction + '\n\n' + fewshot_prompt + '\n\n' + new_case_prompt

        return prompt

    async def async_query2sql(self, query_text: str, filter_condition: Mapping[str, str],
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        logger.info("query_text: {}".format(query_text))
        logger.info("model_name: {}".format(model_name))
        logger.info("fields_list: {}".format(fields_list))
        logger.info("data_date: {}".format(data_date))
        logger.info("prior_schema_links: {}".format(prior_schema_links))
        logger.info("prior_exts: {}".format(prior_exts))

        if prior_exts != '':
            query_text = query_text + ' 备注:' + prior_exts
            logger.info("query_text_prior_exts: {}".format(query_text))

        fewshot_example_meta_list = self.get_fewshot_examples(query_text, filter_condition)
        schema_linking_prompt = self.generate_schema_linking_prompt(query_text, model_name, fields_list, prior_schema_links, fewshot_example_meta_list)
        logger.debug("schema_linking_prompt->{}".format(schema_linking_prompt))
        schema_link_output = await self.llm._call_async(schema_linking_prompt)

        schema_link_str = schema_link_parse(schema_link_output)

        sql_prompt = self.generate_sql_prompt(query_text, model_name, schema_link_str, data_date, fewshot_example_meta_list)
        logger.debug("sql_prompt->{}".format(sql_prompt))
        sql_output = await self.llm._call_async(sql_prompt)

        resp = dict()
        resp['query'] = query_text
        resp['model'] = model_name
        resp['fields'] = fields_list
        resp['priorSchemaLinking'] = prior_schema_links
        resp['dataDate'] = data_date

        resp['schemaLinkingOutput'] = schema_link_output
        resp['schemaLinkStr'] = schema_link_str

        resp['sqlOutput'] = sql_output

        logger.info("resp: {}".format(resp))

        return resp

    async def async_query2sql_shortcut(self, query_text: str, filter_condition: Mapping[str, str],
                                       model_name: str, fields_list: List[str],
                                       data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        logger.info("query_text: {}".format(query_text))
        logger.info("model_name: {}".format(model_name))
        logger.info("fields_list: {}".format(fields_list))
        logger.info("data_date: {}".format(data_date))
        logger.info("prior_schema_links: {}".format(prior_schema_links))
        logger.info("prior_exts: {}".format(prior_exts))

        if prior_exts != '':
            query_text = query_text + ' 备注:' + prior_exts
            logger.info("query_text_prior_exts: {}".format(query_text))

        fewshot_example_meta_list = self.get_fewshot_examples(query_text, filter_condition)
        schema_linking_sql_shortcut_prompt = self.generate_schema_linking_sql_prompt(query_text, model_name, data_date, fields_list, prior_schema_links, fewshot_example_meta_list)
        logger.debug("schema_linking_sql_shortcut_prompt->{}".format(schema_linking_sql_shortcut_prompt))
        schema_linking_sql_shortcut_output = await self.llm._call_async(schema_linking_sql_shortcut_prompt)

        schema_linking_str = combo_schema_link_parse(schema_linking_sql_shortcut_output)
        sql_str = combo_sql_parse(schema_linking_sql_shortcut_output)

        resp = dict()
        resp['query'] = query_text
        resp['model'] = model_name
        resp['fields'] = fields_list
        resp['priorSchemaLinking'] = prior_schema_links
        resp['dataDate'] = data_date

        resp['schemaLinkingComboOutput'] = schema_linking_sql_shortcut_output
        resp['schemaLinkStr'] = schema_linking_str
        resp['sqlOutput'] = sql_str

        logger.info("resp: {}".format(resp))

        return resp


class Text2DSLAgentConsistency(object):
    def __init__(self, num_fewshots: int, num_examples: int, num_self_consistency: int,
                 sql_example_prompter: FewShotPromptTemplate2, llm: BaseLLM) -> None:
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency

        self.llm = llm
        self.sql_example_prompter = sql_example_prompter

    def reload_setting(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]], num_examples: int, num_fewshots: int, num_self_consistency: int):
        self.num_fewshots = num_fewshots
        self.num_examples = num_examples
        assert self.num_fewshots <= self.num_examples
        self.num_self_consistency = num_self_consistency
        assert self.num_self_consistency >= 1
        self.sql_example_prompter.reload_few_shot_example(sql_example_ids, sql_example_units)

    def add_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        self.sql_example_prompter.add_few_shot_example(sql_example_ids, sql_example_units)

    def update_examples(self, sql_example_ids: List[str], sql_example_units: List[Mapping[str, str]]):
        self.sql_example_prompter.update_few_shot_example(sql_example_ids, sql_example_units)

    def delete_examples(self, sql_example_ids: List[str]):
        self.sql_example_prompter.delete_few_shot_example(sql_example_ids)

    def count_examples(self):
        return self.sql_example_prompter.count_few_shot_example()

    def get_examples_candidates(self, query_text: str, filter_condition: Mapping[str, str]) -> List[Mapping[str, str]]:
        few_shot_example_meta_list = self.sql_example_prompter.retrieve_few_shot_example(query_text, self.num_examples, filter_condition)

        return few_shot_example_meta_list

    def get_fewshot_example_combos(self, example_meta_list: List[Mapping[str, str]]) -> List[List[Mapping[str, str]]]:
        fewshot_example_list = []
        for i in range(0, self.num_self_consistency):
            random.shuffle(example_meta_list)
            fewshot_example_list.append(example_meta_list[:self.num_fewshots])

        return fewshot_example_list

    def generate_schema_linking_prompt(self, user_query: str, domain_name: str, fields_list: List[str],
                                       prior_schema_links: Mapping[str, str], fewshot_example_list: List[Mapping[str, str]]) -> str:
        prior_schema_links_str = '[' + ','.join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) + ']'

        instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"

        schema_linking_example_keys = ["tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", "schemaLinks"]
        schema_linking_example_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"
        schema_linking_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=schema_linking_example_template,
                                                                                               example_keys=schema_linking_example_keys,
                                                                                               few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n问题:{question}\n分析: 让我们一步一步地思考。"
        new_case_prompt = new_case_template.format(tableName=domain_name, fieldsList=fields_list, priorSchemaLinks=prior_schema_links_str, question=user_query)

        schema_linking_prompt = instruction + '\n\n' + schema_linking_fewshot_prompt + '\n\n' + new_case_prompt
        return schema_linking_prompt

    def generate_schema_linking_prompt_pool(self, user_query: str, domain_name: str, fields_list: List[str],
                                            prior_schema_links: Mapping[str, str], fewshot_example_list_pool: List[List[Mapping[str, str]]]) -> List[str]:
        schema_linking_prompt_pool = []
        for fewshot_example_list in fewshot_example_list_pool:
            schema_linking_prompt = self.generate_schema_linking_prompt(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list)
            schema_linking_prompt_pool.append(schema_linking_prompt)

        return schema_linking_prompt_pool

    def generate_sql_prompt(self, user_query: str, domain_name: str,
                            schema_link_str: str, data_date: str,
                            fewshot_example_list: List[Mapping[str, str]]) -> str:
        instruction = "# 根据schema_links为每个问题生成SQL查询语句"
        sql_example_keys = ["question", "currentDate", "tableName", "schemaLinks", "sql"]
        sql_example_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:{sql}"

        sql_example_fewshot_prompt = self.sql_example_prompter.make_few_shot_example_prompt(few_shot_template=sql_example_template,
                                                                                            example_keys=sql_example_keys,
                                                                                            few_shot_example_meta_list=fewshot_example_list)

        new_case_template = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\nSchema_links:{schemaLinks}\nSQL:"
        new_case_prompt = new_case_template.format(question=user_query, currentDate=data_date, tableName=domain_name, schemaLinks=schema_link_str)

        sql_example_prompt = instruction + '\n\n' + sql_example_fewshot_prompt + '\n\n' + new_case_prompt

        return sql_example_prompt

    def generate_sql_prompt_pool(self, user_query: str, domain_name: str, data_date: str,
                                 schema_link_str_pool: List[str], fewshot_example_list_pool: List[List[Mapping[str, str]]]) -> List[str]:
        sql_prompt_pool = []
        for schema_link_str, fewshot_example_list in zip(schema_link_str_pool, fewshot_example_list_pool):
            sql_prompt = self.generate_sql_prompt(user_query, domain_name, schema_link_str, data_date, fewshot_example_list)
            sql_prompt_pool.append(sql_prompt)

        return sql_prompt_pool

    def self_consistency_vote(self, output_res_pool: List[str]):
        output_res_counts = Counter(output_res_pool)
        output_res_max = output_res_counts.most_common(1)[0][0]
        total_output_num = len(output_res_pool)

        vote_percentage = {k: (v / total_output_num) for k, v in output_res_counts.items()}

        return output_res_max, vote_percentage

    def schema_linking_list_str_unify(self, schema_linking_list: List[str]) -> List[str]:
        schema_linking_list_unify = []
        for schema_linking_str in schema_linking_list:
            schema_linking_str_unify = ','.join(sorted([item.strip() for item in schema_linking_str.strip('[]').split(',')]))
            schema_linking_str_unify = f'[{schema_linking_str_unify}]'
            schema_linking_list_unify.append(schema_linking_str_unify)

        return schema_linking_list_unify


    async def generate_schema_linking_tasks(self, user_query: str, domain_name: str,
                                            fields_list: List[str], prior_schema_links: Mapping[str, str],
                                            fewshot_example_list_combo: List[List[Mapping[str, str]]]):
        schema_linking_prompt_pool = self.generate_schema_linking_prompt_pool(user_query, domain_name,
                                                                              fields_list, prior_schema_links,
                                                                              fewshot_example_list_combo)
        schema_linking_output_task_pool = [self.llm._call_async(schema_linking_prompt) for schema_linking_prompt in schema_linking_prompt_pool]
        schema_linking_output_res_pool = await asyncio.gather(*schema_linking_output_task_pool)
        logger.debug(f'schema_linking_output_res_pool:{schema_linking_output_res_pool}')

        return schema_linking_output_res_pool
    async def generate_sql_tasks(self, user_query: str, domain_name: str, data_date: str,
                                 schema_link_str_pool: List[str], fewshot_example_list_combo: List[List[Mapping[str, str]]]):
        # Keep the argument order aligned with generate_sql_prompt_pool(user_query, domain_name, data_date, ...).
        sql_prompt_pool = self.generate_sql_prompt_pool(user_query, domain_name, data_date, schema_link_str_pool, fewshot_example_list_combo)
        sql_output_task_pool = [self.llm._call_async(sql_prompt) for sql_prompt in sql_prompt_pool]
        sql_output_res_pool = await asyncio.gather(*sql_output_task_pool)
        logger.debug(f'sql_output_res_pool:{sql_output_res_pool}')

        return sql_output_res_pool

    async def tasks_run(self, user_query: str, filter_condition: Mapping[str, str], domain_name: str, fields_list: List[str], prior_schema_links: Mapping[str, str], data_date: str, prior_exts: str):
        logger.info("user_query: {}".format(user_query))
        logger.info("domain_name: {}".format(domain_name))
        logger.info("fields_list: {}".format(fields_list))
        logger.info("current_date: {}".format(data_date))
        logger.info("prior_schema_links: {}".format(prior_schema_links))
        logger.info("prior_exts: {}".format(prior_exts))

        if prior_exts != '':
            user_query = user_query + ' 备注:' + prior_exts
            logger.info("user_query_prior_exts: {}".format(user_query))

        fewshot_example_meta_list = self.get_examples_candidates(user_query, filter_condition)
        fewshot_example_list_combo = self.get_fewshot_example_combos(fewshot_example_meta_list)

        schema_linking_output_candidates = await self.generate_schema_linking_tasks(user_query, domain_name, fields_list, prior_schema_links, fewshot_example_list_combo)
        schema_linking_candidate_list = [schema_link_parse(schema_linking_output_candidate) for schema_linking_output_candidate in schema_linking_output_candidates]
        logger.debug(f'schema_linking_candidate_list:{schema_linking_candidate_list}')
        schema_linking_candidate_sorted_list = self.schema_linking_list_str_unify(schema_linking_candidate_list)
        logger.debug(f'schema_linking_candidate_sorted_list:{schema_linking_candidate_sorted_list}')

        schema_linking_output_max, schema_linking_output_vote_percentage = self.self_consistency_vote(schema_linking_candidate_sorted_list)

        sql_output_candidates = await self.generate_sql_tasks(user_query, domain_name, data_date, schema_linking_candidate_list, fewshot_example_list_combo)
        logger.debug(f'sql_output_candidates:{sql_output_candidates}')
        sql_output_max, sql_output_vote_percentage = self.self_consistency_vote(sql_output_candidates)

        resp = dict()
        resp['query'] = user_query
        resp['model'] = domain_name
        resp['fields'] = fields_list
        resp['priorSchemaLinking'] = prior_schema_links
        resp['dataDate'] = data_date

        resp['schemaLinkStr'] = schema_linking_output_max
        resp['schemaLinkingWeight'] = schema_linking_output_vote_percentage

        resp['sqlOutput'] = sql_output_max
        resp['sqlWeight'] = sql_output_vote_percentage

        logger.info("resp: {}".format(resp))

        return resp


class Text2DSLAgentWrapper(object):
    def __init__(self, sql_agent: Text2DSLAgent, sql_agent_cs: Text2DSLAgentConsistency,
                 is_shortcut: bool, is_self_consistency: bool):
        self.sql_agent = sql_agent
        self.sql_agent_cs = sql_agent_cs

        self.is_shortcut = is_shortcut
        self.is_self_consistency = is_self_consistency

    async def async_query2sql(self, query_text: str, filter_condition: Mapping[str, str],
                              model_name: str, fields_list: List[str],
                              data_date: str, prior_schema_links: Mapping[str, str], prior_exts: str):
        if self.is_self_consistency:
            logger.info("sql wrapper: self_consistency")
            resp = await self.sql_agent_cs.tasks_run(user_query=query_text, filter_condition=filter_condition, domain_name=model_name, fields_list=fields_list, prior_schema_links=prior_schema_links, data_date=data_date, prior_exts=prior_exts)
            return resp
        elif self.is_shortcut:
            logger.info("sql wrapper: shortcut")
            resp = await self.sql_agent.async_query2sql_shortcut(query_text=query_text, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp
        else:
            logger.info("sql wrapper: normal")
            resp = await self.sql_agent.async_query2sql(query_text=query_text, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list, data_date=data_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)
            return resp

    def update_configs(self, is_shortcut, is_self_consistency,
                       sql_example_ids, sql_example_units, num_examples, num_fewshots, num_self_consistency):
        self.is_shortcut = is_shortcut
        self.is_self_consistency = is_self_consistency

        # reload_setting (not update_examples) takes the example ids/units plus the sampling sizes,
        # matching the call made by the /query2sql_setting_update endpoint.
        self.sql_agent.reload_setting(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units, num_fewshots=num_examples)
        self.sql_agent_cs.reload_setting(sql_example_ids=sql_example_ids, sql_example_units=sql_example_units, num_examples=num_examples, num_fewshots=num_fewshots, num_self_consistency=num_self_consistency)
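A quick standalone illustration of how self_consistency_vote picks the final answer (the candidate strings are invented):

from collections import Counter

# Three sampled generations; the majority answer wins and the vote share is reported.
candidates = ["SELECT 歌曲名 FROM 歌曲库", "SELECT 歌曲名 FROM 歌曲库", "SELECT * FROM 歌曲库"]
counts = Counter(candidates)
winner = counts.most_common(1)[0][0]
weights = {k: v / len(candidates) for k, v in counts.items()}
print(winner)   # SELECT 歌曲名 FROM 歌曲库
print(weights)  # {'SELECT 歌曲名 FROM 歌曲库': 0.666..., 'SELECT * FROM 歌曲库': 0.333...}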
33
chat/python/services_router/plugin_call_service.py
Normal file
@@ -0,0 +1,33 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from fastapi import APIRouter, Depends, HTTPException

from services.plugin_call.run import plugin_selection_run


router = APIRouter()


@router.post("/plugin_selection/")
async def tool_selection(query_body: Mapping[str, Any]):
    if "queryText" not in query_body:
        raise HTTPException(status_code=400, detail="queryText is not in query_body")
    else:
        query_text = query_body["queryText"]

    if "pluginConfigs" not in query_body:
        raise HTTPException(status_code=400, detail="pluginConfigs is not in query_body")
    else:
        plugin_configs = query_body["pluginConfigs"]

    resp = plugin_selection_run(query_text=query_text, plugin_configs=plugin_configs)

    return resp
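A hypothetical client call for the endpoint above (the pluginConfigs shape and the localhost:9092 address are assumptions, not taken from the repository; use your configured LLMPARSER_HOST/LLMPARSER_PORT):

import requests

body = {
    "queryText": "帮我画一个近30天销量趋势图",
    "pluginConfigs": [{"id": "1", "name": "trend_chart", "description": "绘制趋势图"}],  # invented config
}
resp = requests.post("http://localhost:9092/plugin_selection/", json=body)
print(resp.json())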
71
chat/python/services_router/preset_query_service.py
Normal file
@@ -0,0 +1,71 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from fastapi import APIRouter, Depends, HTTPException

from services.query_retrieval.run import preset_query_retriever


router = APIRouter()


@router.post("/preset_query_retrival")
def preset_query_retrival(query_text_list: List[str], n_results: int = 5):
    parsed_retrieval_res_format = preset_query_retriever.retrieval_query_run(query_texts_list=query_text_list, filter_condition=None, n_results=n_results)

    return parsed_retrieval_res_format


@router.post("/preset_query_add")
def preset_query_add(preset_info_list: List[Mapping[str, str]]):
    preset_queries = []
    preset_query_ids = []

    for preset_info in preset_info_list:
        preset_queries.append(preset_info['preset_query'])
        preset_query_ids.append(preset_info['preset_query_id'])

    preset_query_retriever.add_queries(query_text_list=preset_queries, query_id_list=preset_query_ids, metadatas=None)

    return "success"


@router.post("/preset_query_update")
def preset_query_update(preset_info_list: List[Mapping[str, str]]):
    preset_queries = []
    preset_query_ids = []

    for preset_info in preset_info_list:
        preset_queries.append(preset_info['preset_query'])
        preset_query_ids.append(preset_info['preset_query_id'])

    preset_query_retriever.update_queries(query_text_list=preset_queries, query_id_list=preset_query_ids, metadatas=None)

    return "success"


@router.get("/preset_query_empty")
def preset_query_empty():
    preset_query_retriever.empty_query_collection()

    return "success"


@router.post("/preset_delete_by_ids")
def preset_delete_by_ids(preset_query_ids: List[str]):
    preset_query_retriever.delete_queries_by_ids(preset_query_ids)

    return "success"


@router.post("/preset_get_by_ids")
def preset_get_by_ids(preset_query_ids: List[str]):
    preset_queries = preset_query_retriever.get_query_by_ids(preset_query_ids)

    return preset_queries


@router.get("/preset_query_size")
def preset_query_size():
    size = preset_query_retriever.get_query_size()

    return size
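A minimal payload sketch for /preset_query_add (ids and text are illustrative; field names follow the handler above, and the service address is again an assumption):

import requests

preset_info_list = [
    {"preset_query": "昨天的新增用户数", "preset_query_id": "preset_0"},
    {"preset_query": "本月的活跃用户数", "preset_query_id": "preset_1"},
]
# FastAPI reads the list itself as the JSON request body.
requests.post("http://localhost:9092/preset_query_add", json=preset_info_list)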
161
chat/python/services_router/query2sql_service.py
Normal file
@@ -0,0 +1,161 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from fastapi import APIRouter, Depends, HTTPException

from services.sql.run import text2sql_agent_router

router = APIRouter()


@router.post("/query2sql")
async def query2sql(query_body: Mapping[str, Any]):
    if 'queryText' not in query_body:
        raise HTTPException(status_code=400, detail="queryText is not in query_body")
    else:
        query_text = query_body['queryText']

    if 'schema' not in query_body:
        raise HTTPException(status_code=400, detail="schema is not in query_body")
    else:
        schema = query_body['schema']

    if 'currentDate' not in query_body:
        raise HTTPException(status_code=400, detail="currentDate is not in query_body")
    else:
        current_date = query_body['currentDate']

    if 'linking' not in query_body:
        raise HTTPException(status_code=400, detail="linking is not in query_body")
    else:
        linking = query_body['linking']

    if 'priorExts' not in query_body:
        raise HTTPException(status_code=400, detail="priorExts is not in query_body")
    else:
        prior_exts = query_body['priorExts']

    if 'filterCondition' not in query_body:
        raise HTTPException(status_code=400, detail="filterCondition is not in query_body")
    else:
        filter_condition = query_body['filterCondition']

    model_name = schema['modelName']
    fields_list = schema['fieldNameList']
    prior_schema_links = {item['fieldValue']: item['fieldName'] for item in linking}

    resp = await text2sql_agent_router.async_query2sql(query_text=query_text, filter_condition=filter_condition, model_name=model_name, fields_list=fields_list,
                                                       data_date=current_date, prior_schema_links=prior_schema_links, prior_exts=prior_exts)

    return resp


@router.post("/query2sql_setting_update")
def query2sql_setting_update(query_body: Mapping[str, Any]):
    if 'sqlExamplars' not in query_body:
        raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
    else:
        sql_examplars = query_body['sqlExamplars']

    if 'sqlIds' not in query_body:
        raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
    else:
        sql_ids = query_body['sqlIds']

    if 'exampleNums' not in query_body:
        raise HTTPException(status_code=400, detail="exampleNums is not in query_body")
    else:
        example_nums = query_body['exampleNums']

    if 'fewshotNums' not in query_body:
        raise HTTPException(status_code=400, detail="fewshotNums is not in query_body")
    else:
        fewshot_nums = query_body['fewshotNums']

    if 'selfConsistencyNums' not in query_body:
        raise HTTPException(status_code=400, detail="selfConsistencyNums is not in query_body")
    else:
        self_consistency_nums = query_body['selfConsistencyNums']

    if 'isShortcut' not in query_body:
        raise HTTPException(status_code=400, detail="isShortcut is not in query_body")
    else:
        is_shortcut = query_body['isShortcut']

    if 'isSelfConsistency' not in query_body:
        raise HTTPException(status_code=400, detail="isSelfConsistency is not in query_body")
    else:
        is_self_consistency = query_body['isSelfConsistency']

    text2sql_agent_router.update_configs(is_shortcut=is_shortcut, is_self_consistency=is_self_consistency,
                                         sql_example_ids=sql_ids, sql_example_units=sql_examplars,
                                         num_examples=example_nums, num_fewshots=fewshot_nums, num_self_consistency=self_consistency_nums)

    return "success"


@router.post("/query2sql_add_examples")
def query2sql_add_examples(query_body: Mapping[str, Any]):
    if 'sqlIds' not in query_body:
        raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
    else:
        sql_ids = query_body['sqlIds']

    if 'sqlExamplars' not in query_body:
        raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
    else:
        sql_examplars = query_body['sqlExamplars']

    text2sql_agent_router.sql_agent.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
    text2sql_agent_router.sql_agent_cs.add_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)

    return "success"


@router.post("/query2sql_update_examples")
def query2sql_update_examples(query_body: Mapping[str, Any]):
    if 'sqlIds' not in query_body:
        raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
    else:
        sql_ids = query_body['sqlIds']

    if 'sqlExamplars' not in query_body:
        raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body")
    else:
        sql_examplars = query_body['sqlExamplars']

    text2sql_agent_router.sql_agent.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)
    text2sql_agent_router.sql_agent_cs.update_examples(sql_example_ids=sql_ids, sql_example_units=sql_examplars)

    return "success"


@router.post("/query2sql_delete_examples")
def query2sql_delete_examples(query_body: Mapping[str, Any]):
    if 'sqlIds' not in query_body:
        raise HTTPException(status_code=400, detail="sqlIds is not in query_body")
    else:
        sql_ids = query_body['sqlIds']

    text2sql_agent_router.sql_agent.delete_examples(sql_example_ids=sql_ids)
    text2sql_agent_router.sql_agent_cs.delete_examples(sql_example_ids=sql_ids)

    return "success"


@router.get("/query2sql_count_examples")
def query2sql_count_examples():
    sql_agent_examples_cnt = text2sql_agent_router.sql_agent.count_examples()
    sql_agent_cs_examples_cnt = text2sql_agent_router.sql_agent_cs.count_examples()

    assert sql_agent_examples_cnt == sql_agent_cs_examples_cnt

    return sql_agent_examples_cnt
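For reference, a request-body sketch for /query2sql; the schema and linking values are invented, but the keys are exactly those checked by the handler above:

query_body = {
    "queryText": "近7天周杰伦的歌曲播放量",
    "schema": {"modelName": "歌曲库", "fieldNameList": ["歌曲名", "歌手名", "播放量", "数据日期"]},
    "currentDate": "2023-08-01",
    "linking": [{"fieldValue": "周杰伦", "fieldName": "歌手名"}],
    "priorExts": "",
    "filterCondition": None,
}
# POST this dict as JSON to /query2sql; the 'linking' entries become the
# prior_schema_links mapping {'周杰伦': '歌手名'}.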
156
chat/python/services_router/retriever_service.py
Normal file
@@ -0,0 +1,156 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from fastapi import APIRouter, Depends, HTTPException

from services.query_retrieval.run import collection_manager
from services.query_retrieval.retriever import ChromaCollectionRetriever

router = APIRouter()


@router.get("/list_collections")
def list_collections():
    collections = collection_manager.list_collections()

    return collections


@router.get("/create_collection")
def create_collection(collection_name: str):
    collection_manager.create_collection(collection_name)

    return "success"


@router.get("/delete_collection")
def delete_collection(collection_name: str):
    collection_manager.delete_collection(collection_name)

    return "success"


@router.get("/get_collection")
def get_collection(collection_name: str):
    collection = collection_manager.get_collection(collection_name)

    return collection


@router.get("/get_or_create_collection")
def get_or_create_collection(collection_name: str):
    collection = collection_manager.get_or_create_collection(collection_name)

    return collection


@router.post("/add_query")
def query_add(collection_name: str, query_info_list: List[Mapping[str, Any]]):
    queries = []
    query_ids = []
    metadatas = []
    embeddings = []

    for query_info in query_info_list:
        queries.append(query_info['query'])
        query_ids.append(query_info['queryId'])
        metadatas.append(query_info['metadata'])
        embeddings.append(query_info['queryEmbedding'])

    if None in embeddings:
        embeddings = None
    if None in queries:
        queries = None

    if embeddings is None and queries is None:
        raise HTTPException(status_code=400, detail="query and queryEmbedding are None")
    if embeddings is not None and queries is not None:
        raise HTTPException(status_code=400, detail="query and queryEmbedding are not None")

    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)
    query_retriever.add_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas, embeddings=embeddings)

    return "success"


@router.post("/update_query")
def update_query(collection_name: str, query_info_list: List[Mapping[str, Any]]):
    queries = []
    query_ids = []
    metadatas = []
    embeddings = []

    for query_info in query_info_list:
        queries.append(query_info['query'])
        query_ids.append(query_info['queryId'])
        metadatas.append(query_info['metadata'])
        embeddings.append(query_info['queryEmbedding'])

    if None in embeddings:
        embeddings = None
    if None in queries:
        queries = None

    if embeddings is None and queries is None:
        raise HTTPException(status_code=400, detail="query and queryEmbedding are None")
    if embeddings is not None and queries is not None:
        raise HTTPException(status_code=400, detail="query and queryEmbedding are not None")

    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)
    query_retriever.update_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas, embeddings=embeddings)

    return "success"


@router.get("/empty_query")
def empty_query(collection_name: str):
    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)
    query_retriever.empty_query_collection()

    return "success"


@router.post("/delete_query_by_ids")
def delete_query_by_ids(collection_name: str, query_ids: List[str]):
    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)
    query_retriever.delete_queries_by_ids(query_ids=query_ids)

    return "success"


@router.post("/get_query_by_ids")
def get_query_by_ids(collection_name: str, query_ids: List[str]):
    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)
    queries = query_retriever.get_query_by_ids(query_ids=query_ids)

    return queries


@router.get("/query_size")
def query_size(collection_name: str):
    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)
    size = query_retriever.get_query_size()

    return size


@router.post("/retrieve_query")
def retrieve_query(collection_name: str, query_info: Mapping[str, Any], n_results: int = 10):
    query_collection = collection_manager.get_or_create_collection(collection_name=collection_name)
    query_retriever = ChromaCollectionRetriever(collection=query_collection)

    query_texts_list = query_info['queryTextsList']
    query_embeddings = query_info['queryEmbeddings']
    filter_condition = query_info['filterCondition']

    if query_texts_list is None and query_embeddings is None:
        raise HTTPException(status_code=400, detail="query and queryEmbedding are None")
    if query_texts_list is not None and query_embeddings is not None:
        raise HTTPException(status_code=400, detail="query and queryEmbedding are not None")

    parsed_retrieval_res_format = query_retriever.retrieval_query_run(query_texts_list=query_texts_list,
                                                                      query_embeddings=query_embeddings,
                                                                      filter_condition=filter_condition,
                                                                      n_results=n_results)

    return parsed_retrieval_res_format
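The add/update/retrieve endpoints above enforce an exclusive-or: each batch supplies either plain query texts or precomputed embeddings, never both. A payload sketch for /add_query (collection name, ids and metadata are illustrative):

query_info_list = [
    {"query": "昨天的播放量", "queryId": "q1", "metadata": {"modelId": "1"}, "queryEmbedding": None},
    {"query": "上周的新增用户", "queryId": "q2", "metadata": {"modelId": "1"}, "queryEmbedding": None},
]
# POST to /add_query?collection_name=my_collection with this list as the JSON body;
# queryEmbedding=None on every item means the collection's embedding function is used.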
80
chat/python/services_router/solved_query_service.py
Normal file
@@ -0,0 +1,80 @@
# -*- coding:utf-8 -*-
import os
import sys
from typing import Any, List, Mapping, Optional, Union

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from fastapi import APIRouter, Depends, HTTPException

from services.query_retrieval.run import solved_query_retriever

router = APIRouter()


@router.post("/solved_query_retrival")
def solved_query_retrival(query_info: Mapping[str, Any], n_results: int = 5):
    query_texts_list = query_info['queryTextsList']
    filter_condition = query_info['filterCondition']

    parsed_retrieval_res_format = solved_query_retriever.retrieval_query_run(query_texts_list=query_texts_list,
                                                                             filter_condition=filter_condition,
                                                                             n_results=n_results)

    return parsed_retrieval_res_format


@router.post("/solved_query_add")
def add_solved_queries(solved_query_info_list: List[Mapping[str, Any]]):
    queries = []
    query_ids = []
    metadatas = []

    for solved_query_info in solved_query_info_list:
        queries.append(solved_query_info['query'])
        query_ids.append(solved_query_info['query_id'])
        metadatas.append(solved_query_info['metadata'])

    solved_query_retriever.add_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas)

    return "success"


@router.post("/solved_query_update")
def solved_query_update(solved_query_info_list: List[Mapping[str, Any]]):
    queries = []
    query_ids = []
    metadatas = []

    for solved_query_info in solved_query_info_list:
        queries.append(solved_query_info['query'])
        query_ids.append(solved_query_info['query_id'])
        metadatas.append(solved_query_info['metadata'])

    solved_query_retriever.update_queries(query_text_list=queries, query_id_list=query_ids, metadatas=metadatas)

    return "success"


@router.get("/solved_query_empty")
def solved_query_empty():
    solved_query_retriever.empty_query_collection()

    return "success"


@router.post("/solved_query_delete_by_ids")
def solved_query_delete_by_ids(query_ids: List[str]):
    solved_query_retriever.delete_queries_by_ids(query_ids=query_ids)

    return "success"


@router.post("/solved_query_get_by_ids")
def solved_query_get_by_ids(query_ids: List[str]):
    queries = solved_query_retriever.get_query_by_ids(query_ids=query_ids)

    return queries


@router.get("/solved_query_size")
def solved_query_size():
    size = solved_query_retriever.get_query_size()

    return size
33
chat/python/supersonic_llmparser.py
Normal file
@@ -0,0 +1,33 @@
# -*- coding:utf-8 -*-
import os
import sys

import uvicorn

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from typing import Any, List, Mapping

from fastapi import FastAPI, HTTPException

from config.config_parse import LLMPARSER_HOST, LLMPARSER_PORT

from services_router import (query2sql_service, preset_query_service,
                             solved_query_service, plugin_call_service, retriever_service)


app = FastAPI()

@app.get("/health")
def read_health():
    return {"status": "Healthy"}

app.include_router(preset_query_service.router)
app.include_router(solved_query_service.router)
app.include_router(query2sql_service.router)
app.include_router(plugin_call_service.router)
app.include_router(retriever_service.router)

if __name__ == "__main__":
    uvicorn.run(app, host=LLMPARSER_HOST, port=LLMPARSER_PORT)
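To smoke-test the assembled service (the address below is an assumption; substitute the LLMPARSER_HOST and LLMPARSER_PORT values from your run_config.ini):

# python supersonic_llmparser.py   # starts uvicorn on LLMPARSER_HOST:LLMPARSER_PORT
import requests

print(requests.get("http://localhost:9092/health").json())  # expected: {'status': 'Healthy'}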
156
chat/python/utils/chromadb_utils.py
Normal file
@@ -0,0 +1,156 @@
# -*- coding:utf-8 -*-
from typing import Any, List, Mapping, Optional, Union

import chromadb
from chromadb.config import Settings
from chromadb.api import Collection, Documents, Embeddings

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from instances.logging_instance import logger


def empty_chroma_collection_2(collection: Collection):
    # Recreate the collection from scratch: chroma has no truncate, so delete it and
    # rebuild with the same name, metadata and embedding function.
    collection_name = collection.name
    client = collection._client
    metadata = collection.metadata
    embedding_function = collection._embedding_function

    client.delete_collection(collection_name)

    new_collection = client.get_or_create_collection(name=collection_name,
                                                     metadata=metadata,
                                                     embedding_function=embedding_function)

    size_of_new_collection = new_collection.count()

    logger.info(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}')

    return new_collection


def empty_chroma_collection(collection: Collection) -> None:
    collection.delete()


def add_chroma_collection(collection: Collection,
                          queries: List[str],
                          query_ids: List[str],
                          metadatas: List[Mapping[str, str]] = None,
                          embeddings: Embeddings = None
                          ) -> None:
    collection.add(documents=queries,
                   ids=query_ids,
                   metadatas=metadatas,
                   embeddings=embeddings)


def update_chroma_collection(collection: Collection,
                             queries: List[str],
                             query_ids: List[str],
                             metadatas: List[Mapping[str, str]] = None,
                             embeddings: Embeddings = None
                             ) -> None:
    collection.update(documents=queries,
                      ids=query_ids,
                      metadatas=metadatas,
                      embeddings=embeddings)


def query_chroma_collection(collection: Collection, query_texts: List[str] = None, query_embeddings: Embeddings = None,
                            filter_condition: Mapping[str, str] = None, n_results: int = 10):
    # A single-key filter is passed through as-is; multiple keys are combined
    # into a chroma `$and` of `$eq` clauses.
    outer_opt = '$and'
    inner_opt = '$eq'

    if filter_condition is not None:
        if len(filter_condition) == 1:
            outer_filter = filter_condition
        else:
            inner_filter = [{_k: {inner_opt: _v}} for _k, _v in filter_condition.items()]
            outer_filter = {outer_opt: inner_filter}
    else:
        outer_filter = None

    logger.info('outer_filter: {}'.format(outer_filter))

    res = collection.query(query_texts=query_texts, query_embeddings=query_embeddings,
                           n_results=n_results, where=outer_filter)
    return res


def parse_retrieval_chroma_collection_query(res: Mapping[str, Any]):
    parsed_res = [[] for _ in range(0, len(res['ids']))]

    retrieval_ids = res['ids']
    retrieval_distances = res['distances']
    retrieval_sentences = res['documents']
    retrieval_metadatas = res['metadatas']

    for query_idx in range(0, len(retrieval_ids)):
        id_ls = retrieval_ids[query_idx]
        distance_ls = retrieval_distances[query_idx]
        sentence_ls = retrieval_sentences[query_idx]
        metadata_ls = retrieval_metadatas[query_idx]

        for idx in range(0, len(id_ls)):
            id = id_ls[idx]
            distance = distance_ls[idx]
            sentence = sentence_ls[idx]
            metadata = metadata_ls[idx]

            parsed_res[query_idx].append({
                'id': id,
                'distance': distance,
                'query': sentence,
                'metadata': metadata
            })

    return parsed_res


def chroma_collection_query_retrieval_format(query_list: List[str], query_embeddings: Embeddings, retrieval_list: List[Mapping[str, Any]]):
    res = []

    if query_list is not None and query_embeddings is not None:
        raise Exception("query_list and query_embeddings are not None")
    if query_list is None and query_embeddings is None:
        raise Exception("query_list and query_embeddings are None")

    if query_list is not None:
        for query_idx in range(0, len(query_list)):
            query = query_list[query_idx]
            retrieval = retrieval_list[query_idx]

            res.append({
                'query': query,
                'retrieval': retrieval
            })
    else:
        for query_idx in range(0, len(query_embeddings)):
            query_embedding = query_embeddings[query_idx]
            retrieval = retrieval_list[query_idx]

            res.append({
                'query_embedding': query_embedding,
                'retrieval': retrieval
            })

    return res


def delete_chroma_collection_by_ids(collection: Collection, query_ids: List[str]) -> None:
    collection.delete(ids=query_ids)


def get_chroma_collection_by_ids(collection: Collection, query_ids: List[str]):
    res = collection.get(ids=query_ids)

    return res


def get_chroma_collection_size(collection: Collection) -> int:
    return collection.count()
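A small worked example of the filter translation in query_chroma_collection (values are invented):

filter_condition = {"modelId": "3", "agentId": "1"}
# Two or more keys become an $and of $eq clauses for chroma's `where`:
#   {'$and': [{'modelId': {'$eq': '3'}}, {'agentId': {'$eq': '1'}}]}
# A single key, e.g. {"modelId": "3"}, is passed through unchanged.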