mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
104 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f60c1675cd | ||
|
|
1d9b6d6877 | ||
|
|
d8930e8906 | ||
|
|
c68df24375 | ||
|
|
bb1001677d | ||
|
|
7a1cfbcef8 | ||
|
|
67b9c4bf79 | ||
|
|
7cb7697353 | ||
|
|
3e18655c69 | ||
|
|
e7d52f87f0 | ||
|
|
2cd8f8022b | ||
|
|
e08435902a | ||
|
|
b44fa2bf3c | ||
|
|
d7f1f06daf | ||
|
|
4c26e0c972 | ||
|
|
d7fafa361d | ||
|
|
0c69651ef3 | ||
|
|
b5fdbfbbf6 | ||
|
|
33a2688e77 | ||
|
|
6bd97cd8af | ||
|
|
64615cbef9 | ||
|
|
dfb3b59984 | ||
|
|
61641ecb00 | ||
|
|
5016881ce3 | ||
|
|
fe75b3e393 | ||
|
|
3db443f9b1 | ||
|
|
59c21ea19a | ||
|
|
95334441b1 | ||
|
|
276b224c13 | ||
|
|
f03da53d6f | ||
|
|
9201550027 | ||
|
|
c86cd9f901 | ||
|
|
ef8caea9d2 | ||
|
|
6daaff8c30 | ||
|
|
4b00c16eb7 | ||
|
|
4dae84034e | ||
|
|
e6eac03ec6 | ||
|
|
e9a479e2df | ||
|
|
7db1cc270e | ||
|
|
3bf5b86535 | ||
|
|
3d30632b41 | ||
|
|
287a6561ff | ||
|
|
169262cc62 | ||
|
|
fda5a577d6 | ||
|
|
f89be48e98 | ||
|
|
2c7afd0d55 | ||
|
|
2ad0553f6c | ||
|
|
340cb2c835 | ||
|
|
caefa501f2 | ||
|
|
5c96d75d39 | ||
|
|
86c2f96942 | ||
|
|
73899e3174 | ||
|
|
49bb2c6d8b | ||
|
|
9223a4f856 | ||
|
|
f3f60af231 | ||
|
|
3cdfcae01c | ||
|
|
0c6efada43 | ||
|
|
d79f73eab6 | ||
|
|
3ae720ef30 | ||
|
|
221e88de0f | ||
|
|
23d926f195 | ||
|
|
97b11ec244 | ||
|
|
899047dbd1 | ||
|
|
cb4b91878f | ||
|
|
6af661459c | ||
|
|
0e0ba51750 | ||
|
|
a5c32ac064 | ||
|
|
abbe8c84a1 | ||
|
|
6c0f88d8b5 | ||
|
|
68ada561ac | ||
|
|
18b52ec742 | ||
|
|
ca8d7d89c1 | ||
|
|
e6ab7cb5ff | ||
|
|
9679169e6f | ||
|
|
ed0f856438 | ||
|
|
9aa5c93d9d | ||
|
|
b45592c009 | ||
|
|
6e0fa95e6f | ||
|
|
94f310d17f | ||
|
|
2bc29d64a4 | ||
|
|
c220ca69c2 | ||
|
|
4280aad0a7 | ||
|
|
c98d15059b | ||
|
|
a862a83272 | ||
|
|
c6d59701db | ||
|
|
39a85dc4ed | ||
|
|
507c02a8fd | ||
|
|
380597f0c3 | ||
|
|
e469c449b4 | ||
|
|
f8bdb8a4b4 | ||
|
|
d76216a2ec | ||
|
|
82cfb3050d | ||
|
|
57f7d0c67d | ||
|
|
c11a242f34 | ||
|
|
576fad5fb1 | ||
|
|
8171d754e0 | ||
|
|
6be0f02c75 | ||
|
|
95e3138ab2 | ||
|
|
3a30a1a317 | ||
|
|
46733d1728 | ||
|
|
b6734d99e1 | ||
|
|
9cb01149f8 | ||
|
|
db88127da9 | ||
|
|
0e492ef402 |
@@ -4,6 +4,15 @@
|
||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||
compatibility issues with previous versions.
|
||||
|
||||
## SuperSonic [0.8.2] - 2023-12-18
|
||||
|
||||
### Added
|
||||
- rewrite Python service with Java project, default to Java implementation.
|
||||
- support setting the SQL generation method for large models in the interface.
|
||||
- optimization of metric market experience.
|
||||
- optimization of semantic modeling canvas experience.
|
||||
- code structure adjustment and abstraction optimization for chat.
|
||||
|
||||
## SuperSonic [0.7.5] - 2023-10-13
|
||||
|
||||
### Added
|
||||
|
||||
16
README.md
16
README.md
@@ -2,25 +2,25 @@
|
||||
|
||||
# SuperSonic (超音数)
|
||||
|
||||
**SuperSonic is an out-of-the-box yet highly extensible framework for building ChatBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of metrics/dimensions/entities, along with their meaning, context and relationships) on top of physical data models, and no data modification or copying is required. Meanwhile, SuperSonic is designed to be pluggable, allowing new functionalities to be added through plugins and core components to be integrated with other systems.
|
||||
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||
|
||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||
|
||||
## Motivation
|
||||
|
||||
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language into SQL (so called text2sql or nl2sql). While some works exhibit promising results, their **reliability** is inadequate for real-world applications.
|
||||
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language into SQL (so called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications.
|
||||
|
||||
From our perspective, the key to filling the real-world gap lies in three aspects:
|
||||
1. Introduce a semantic layer encapsulating underlying data context(joins, formulas, etc) to reduce **complexity**.
|
||||
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to mitigate **hallucination**.
|
||||
3. Utilize heuristic rules when necessary to improve **efficiency**(in terms of latency and cost).
|
||||
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
|
||||
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to **mitigate hallucination**.
|
||||
3. Utilize rule-based schema parsers when necessary to **improve efficiency**(in terms of latency and cost).
|
||||
|
||||
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of ChatBI, we decide to open source SuperSonic as an extensible framework.
|
||||
|
||||
## Out-of-the-box Features
|
||||
|
||||
- Built-in CUI(Chat User Interface) for *business users* to enter data queries
|
||||
- Built-in GUI(Graphical User Interface) for *analytics engineers* to build semantic models
|
||||
- Built-in ChatBI interface for *business users* to enter natural language queries
|
||||
- Built-in HeadlessBI interface for *analytics engineers* to build semantic models
|
||||
- Built-in GUI for *system administrators* to manage chat agents and third-party plugins
|
||||
- Support input auto-completion as well as query recommendation
|
||||
- Support multi-turn conversation and history context management
|
||||
@@ -49,7 +49,7 @@ The high-level architecture and main process flow is as follows:
|
||||
SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:
|
||||
|
||||
- Download the latest prebuilt binary from the [release page](https://github.com/tencentmusic/supersonic/releases)
|
||||
- Run script "bin/supersonic-daemon.sh" to start services (one java process and one python process)
|
||||
- Run script "assembly/bin/supersonic-daemon.sh start" to start a standalone Java service
|
||||
- Visit http://localhost:9080 in the browser to start exploration
|
||||
|
||||
## Build and Development
|
||||
|
||||
20
README_CN.md
20
README_CN.md
@@ -1,6 +1,6 @@
|
||||
# 超音数(SuperSonic)
|
||||
# SuperSonic (超音数)
|
||||
|
||||
**超音数是一个开箱即用且易于扩展的数据问答对话框架**。通过超音数的问答对话界面,用户能够使用自然语言查询数据,系统会选择合适的可视化图表呈现结果。超音数不需要修改或复制数据,只需要在物理数据模型之上构建逻辑语义模型(指标/维度/实体的定义,以及他们的业务含义、相互间关系等),即可开启数据问答体验。与此同时,超音数被设计为可插拔式的框架,允许以插件形式来扩展新功能,或者将核心组件与其他系统集成。
|
||||
**SuperSonic融合ChatBI和HeadlessBI打造新一代的数据分析平台**。通过SuperSonic的问答对话界面,用户能够使用自然语言查询数据,系统会选择合适的可视化图表呈现结果。SuperSonic不需要修改或复制数据,只需要在物理数据模型之上构建逻辑语义模型(指标/维度/实体的定义,以及他们的业务含义、相互间关系等),即可开启数据问答体验。与此同时,SuperSonic被设计为可插拔的框架,采用Java SPI机制来扩展定制功能。
|
||||
|
||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||
|
||||
@@ -9,24 +9,24 @@
|
||||
大型语言模型(LLMs)如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们的可靠性还达不到生产可用的要求。
|
||||
|
||||
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
||||
1. 引入语义模型层,封装底层数据的上下文(关联、公式等),降低SQL生成的**复杂度**。
|
||||
1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。
|
||||
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
|
||||
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
|
||||
|
||||
为了验证上述想法,我们开发了超音数项目,并将其应用在实际的内部产品中。与此同时,我们将超音数作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。
|
||||
为了验证上述想法,我们开发了SuperSonic项目,并将其应用在实际的内部产品中。与此同时,我们将SuperSonic作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。
|
||||
|
||||
## 开箱即用的特性
|
||||
|
||||
- 内置对话界面以便*业务用户*输入数据查询。
|
||||
- 内置图形界面以便*分析工程师*构建语义模型。
|
||||
- 内置图形界面以便*系统管理员*管理第三方插件和对话助理。
|
||||
- 内置ChatBI界面以便*业务用户*输入数据查询。
|
||||
- 内置HeadlessBI界面以便*分析工程师*构建语义模型。
|
||||
- 内置图形用户界面以便*系统管理员*管理第三方插件和对话助理。
|
||||
- 支持文本输入的联想和查询问题的推荐。
|
||||
- 支持多轮对话,根据语境自动切换上下文。
|
||||
- 支持四级权限控制:主题域级、模型级、列级、行级。
|
||||
|
||||
## 易于扩展的组件
|
||||
|
||||
超音数的整体架构和主流程如下图所示:
|
||||
SuperSonic的整体架构和主流程如下图所示:
|
||||
|
||||
<img src="./docs/images/supersonic_components.png" height="65%" width="65%" align="center"/>
|
||||
|
||||
@@ -44,10 +44,10 @@
|
||||
|
||||
## 快速体验
|
||||
|
||||
超音数自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
||||
SuperSonic自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
||||
|
||||
- 从[release page](https://github.com/tencentmusic/supersonic/releases)下载预先构建好的发行包
|
||||
- 运行 "bin/supersonic-daemon.sh"启动服务(一个Java进程和一个Python进程)
|
||||
- 运行 "assembly/bin/supersonic-daemon.sh start"启动standalone模式的Java服务
|
||||
- 在浏览器访问http://localhost:9080 开启探索
|
||||
|
||||
## 如何构建和部署
|
||||
|
||||
@@ -6,11 +6,18 @@ set "baseDir=%~dp0.."
|
||||
set "buildDir=%baseDir%\build"
|
||||
set "runtimeDir=%baseDir%\..\runtime"
|
||||
set "pip_path=pip3"
|
||||
set "service=%~1"
|
||||
|
||||
|
||||
rem 1. build backend java modules
|
||||
del /q "%buildDir%\*.tar.gz" 2>NUL
|
||||
call mvn -f "%baseDir%\..\pom.xml" clean package -DskipTests
|
||||
|
||||
IF ERRORLEVEL 1 (
|
||||
ECHO Failed to build backend Java modules.
|
||||
EXIT /B 1
|
||||
)
|
||||
|
||||
rem 2. move package to build
|
||||
echo f|xcopy "%baseDir%\..\launchers\standalone\target\*.tar.gz" "%buildDir%\supersonic-standalone.tar.gz"
|
||||
|
||||
@@ -19,6 +26,11 @@ cd "%baseDir%\..\webapp"
|
||||
call start-fe-prod.bat
|
||||
copy /y "%baseDir%\..\webapp\supersonic-webapp.tar.gz" "%buildDir%\"
|
||||
|
||||
IF ERRORLEVEL 1 (
|
||||
ECHO Failed to build frontend webapp.
|
||||
EXIT /B 1
|
||||
)
|
||||
|
||||
rem 4. copy webapp to java classpath
|
||||
cd "%buildDir%"
|
||||
tar -zxvf supersonic-webapp.tar.gz
|
||||
@@ -26,16 +38,23 @@ move supersonic-webapp webapp
|
||||
move webapp ..\..\launchers\standalone\target\classes
|
||||
|
||||
rem 5. build backend python modules
|
||||
echo "start installing python modules with pip: ${pip_path}"
|
||||
set requirementPath="%baseDir%/../chat/python/requirements.txt"
|
||||
%pip_path% install -r %requirementPath%
|
||||
echo "install python modules success"
|
||||
if "%service%"=="pyllm" (
|
||||
echo "start installing python modules with pip: ${pip_path}"
|
||||
set requirementPath="%baseDir%/../chat/python/requirements.txt"
|
||||
%pip_path% install -r %requirementPath%
|
||||
echo "install python modules success"
|
||||
)
|
||||
|
||||
call :BUILD_RUNTIME
|
||||
|
||||
:BUILD_RUNTIME
|
||||
rem 6. reset runtime
|
||||
rd /s /q "%runtimeDir%"
|
||||
IF EXIST "%runtimeDir%" (
|
||||
echo begin to delete dir : %runtimeDir%
|
||||
rd /s /q "%runtimeDir%"
|
||||
) ELSE (
|
||||
echo %runtimeDir% does not exist, create directly
|
||||
)
|
||||
mkdir "%runtimeDir%"
|
||||
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
|
||||
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
|
||||
|
||||
@@ -4,16 +4,19 @@ set -x
|
||||
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||
chmod +x $sbinDir/supersonic-common.sh
|
||||
source $sbinDir/supersonic-common.sh
|
||||
|
||||
cd $baseDir
|
||||
|
||||
service=$1
|
||||
#1. build backend java modules
|
||||
rm -fr ${buildDir}/*.tar.gz
|
||||
rm -fr dist
|
||||
|
||||
set +x
|
||||
|
||||
mvn -f $baseDir/../ clean package -DskipTests
|
||||
# check build result
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed to build backend Java modules."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
#2. move package to build
|
||||
cp $baseDir/../launchers/semantic/target/*.tar.gz ${buildDir}/supersonic-semantic.tar.gz
|
||||
@@ -26,6 +29,11 @@ cd ../webapp
|
||||
sh ./start-fe-prod.sh
|
||||
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/
|
||||
|
||||
# check build result
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Failed to build frontend webapp."
|
||||
exit 1
|
||||
fi
|
||||
#4. copy webapp to java classpath
|
||||
cd $buildDir
|
||||
tar xvf supersonic-webapp.tar.gz
|
||||
@@ -36,13 +44,15 @@ cp -fr webapp ../../launchers/standalone/target/classes
|
||||
rm -fr ${buildDir}/webapp
|
||||
|
||||
#5. build backend python modules
|
||||
echo "start installing python modules with pip: ${pip_path}"
|
||||
requirementPath=$baseDir/../chat/python/requirements.txt
|
||||
${pip_path} install -r ${requirementPath}
|
||||
echo "install python modules success"
|
||||
if [ "$service" == "pyllm" ]; then
|
||||
echo "start installing python modules with pip: ${pip_path}"
|
||||
requirementPath=$baseDir/../chat/python/requirements.txt
|
||||
${pip_path} install -r ${requirementPath}
|
||||
echo "install python modules success"
|
||||
fi
|
||||
|
||||
#6. reset runtime
|
||||
rm -fr $runtimeDir/*
|
||||
rm -fr $runtimeDir/supersonic*
|
||||
moveAllToRuntime
|
||||
setEnvToWeb chat
|
||||
setEnvToWeb semantic
|
||||
|
||||
@@ -11,14 +11,14 @@ buildDir=$baseDir/build
|
||||
|
||||
readonly CHAT_APP_NAME="supersonic_chat"
|
||||
readonly SEMANTIC_APP_NAME="supersonic_semantic"
|
||||
readonly LLMPARSER_APP_NAME="supersonic_llmparser"
|
||||
readonly PYLLM_APP_NAME="supersonic_pyllm"
|
||||
readonly STANDALONE_APP_NAME="supersonic_standalone"
|
||||
readonly CHAT_SERVICE="chat"
|
||||
readonly SEMANTIC_SERVICE="semantic"
|
||||
readonly LLMPARSER_SERVICE="llmparser"
|
||||
readonly PYLLM_SERVICE="pyllm"
|
||||
readonly STANDALONE_SERVICE="standalone"
|
||||
readonly LLMPARSER_HOST="127.0.0.1"
|
||||
readonly LLMPARSER_PORT="9092"
|
||||
readonly PYLLM_HOST="127.0.0.1"
|
||||
readonly PYLLM_PORT="9092"
|
||||
|
||||
function setEnvToWeb {
|
||||
model_name=$1
|
||||
@@ -29,11 +29,15 @@ function setEnvToWeb {
|
||||
|
||||
function moveToRuntime {
|
||||
model_name=$1
|
||||
tar -zxvf ${buildDir}/supersonic-${model_name}.tar.gz -C ${runtimeDir}
|
||||
mv ${runtimeDir}/launchers-${model_name}-* ${runtimeDir}/supersonic-${model_name}
|
||||
|
||||
mkdir -p ${runtimeDir}/supersonic-${model_name}/webapp
|
||||
cp -fr ${buildDir}/webapp/* ${runtimeDir}/supersonic-${model_name}/webapp
|
||||
file="${buildDir}/supersonic-${model_name}.tar.gz"
|
||||
if [ -f "$file" ]; then
|
||||
tar -zxvf "$file" -C ${runtimeDir}
|
||||
mv ${runtimeDir}/launchers-${model_name}-* ${runtimeDir}/supersonic-${model_name}
|
||||
mkdir -p ${runtimeDir}/supersonic-${model_name}/webapp
|
||||
cp -fr ${buildDir}/webapp/* ${runtimeDir}/supersonic-${model_name}/webapp
|
||||
else
|
||||
echo "File $file does not exist. Skipping the move to runtime."
|
||||
fi
|
||||
}
|
||||
|
||||
function moveAllToRuntime {
|
||||
@@ -81,23 +85,23 @@ function runJavaService {
|
||||
|
||||
# run python service
|
||||
function runPythonService {
|
||||
pythonRunDir=${runtimeDir}/supersonic-${model_name}/llmparser
|
||||
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
|
||||
cd $pythonRunDir
|
||||
nohup ${python_path} supersonic_llmparser.py > $pythonRunDir/llmparser.log 2>&1 &
|
||||
nohup ${python_path} supersonic_pyllm.py > $pythonRunDir/pyllm.log 2>&1 &
|
||||
# add health check
|
||||
for i in {1..10}
|
||||
do
|
||||
echo "llmparser health check attempt $i..."
|
||||
response=$(curl -s http://${LLMPARSER_HOST}:${LLMPARSER_PORT}/health)
|
||||
echo "llmparser health check response: $response"
|
||||
echo "pyllm health check attempt $i..."
|
||||
response=$(curl -s http://${PYLLM_HOST}:${PYLLM_PORT}/health)
|
||||
echo "pyllm health check response: $response"
|
||||
status_ok="Healthy"
|
||||
if [[ $response == *$status_ok* ]] ; then
|
||||
echo "llmparser Health check passed."
|
||||
echo "pyllm Health check passed."
|
||||
break
|
||||
else
|
||||
if [ "$i" -eq 10 ]; then
|
||||
echo "llmparser Health check failed after 10 attempts."
|
||||
echo "May still downloading model files. Please check llmparser.log in runtime directory."
|
||||
echo "pyllm Health check failed after 10 attempts."
|
||||
echo "May still downloading model files. Please check pyllm.log in runtime directory."
|
||||
fi
|
||||
echo "Retrying after 5 seconds..."
|
||||
sleep 5
|
||||
|
||||
@@ -9,10 +9,10 @@ set "main_class=com.tencent.supersonic.StandaloneLauncher"
|
||||
set "python_path=python"
|
||||
set "pip_path=pip3"
|
||||
set "standalone_service=standalone"
|
||||
set "llmparser_service=llmparser"
|
||||
set "pyllm_service=pyllm"
|
||||
|
||||
set "javaRunDir=%runtimeDir%\supersonic-standalone"
|
||||
set "pythonRunDir=%runtimeDir%\supersonic-standalone\llmparser"
|
||||
set "pythonRunDir=%runtimeDir%\supersonic-standalone\pyllm"
|
||||
|
||||
set "command=%~1"
|
||||
set "service=%~2"
|
||||
@@ -21,6 +21,10 @@ if "%service%"=="" (
|
||||
set "service=%standalone_service%"
|
||||
)
|
||||
|
||||
IF "%service%"=="pyllm" (
|
||||
SET "llmProxy=PythonLLMProxy"
|
||||
)
|
||||
|
||||
call :BUILD_RUNTIME
|
||||
|
||||
if "%command%"=="restart" (
|
||||
@@ -42,27 +46,23 @@ if "%command%"=="restart" (
|
||||
)
|
||||
|
||||
:START
|
||||
if "%service%"=="%llmparser_service%" (
|
||||
if "%service%"=="%pyllm_service%" (
|
||||
call :START_PYTHON
|
||||
call :START_JAVA
|
||||
goto :EOF
|
||||
)
|
||||
call :START_PYTHON
|
||||
call :START_JAVA
|
||||
goto :EOF
|
||||
|
||||
:STOP
|
||||
if "%service%"=="%llmparser_service%" (
|
||||
call :STOP_PYTHON
|
||||
goto :EOF
|
||||
)
|
||||
call :STOP_PYTHON
|
||||
call :STOP_JAVA
|
||||
goto :EOF
|
||||
|
||||
:START_PYTHON
|
||||
echo 'python service starting, see logs in llmparser/llmparser.log'
|
||||
echo 'python service starting, see logs in pyllm/pyllm.log'
|
||||
cd "%pythonRunDir%"
|
||||
start /B %python_path% supersonic_llmparser.py > %pythonRunDir%\llmparser.log 2>&1
|
||||
start /B %python_path% supersonic_pyllm.py > %pythonRunDir%\pyllm.log 2>&1
|
||||
timeout /t 10 >nul
|
||||
echo 'python service started'
|
||||
goto :EOF
|
||||
@@ -71,9 +71,9 @@ if "%command%"=="restart" (
|
||||
echo 'java service starting, see logs in logs/'
|
||||
cd "%javaRunDir%"
|
||||
if not exist "%runtimeDir%\supersonic-standalone\logs" mkdir "%runtimeDir%\supersonic-standalone\logs"
|
||||
set "libDir=%runtimeDir%\supersonic-%service%\lib"
|
||||
set "confDir=%runtimeDir%\supersonic-%service%\conf"
|
||||
set "webDir=%runtimeDir%\supersonic-%service%\webapp"
|
||||
set "libDir=%runtimeDir%\supersonic-standalone\lib"
|
||||
set "confDir=%runtimeDir%\supersonic-standalone\conf"
|
||||
set "webDir=%runtimeDir%\supersonic-standalone\webapp"
|
||||
set "classpath=%confDir%;%webDir%;%libDir%\*"
|
||||
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||
start /B java %java-command% >nul 2>&1
|
||||
@@ -96,7 +96,7 @@ if "%command%"=="restart" (
|
||||
goto :EOF
|
||||
|
||||
:RELOAD_EXAMPLE
|
||||
cd "%runtimeDir%\supersonic-standalone\llmparser\sql"
|
||||
cd "%runtimeDir%\supersonic-standalone\pyllm\sql"
|
||||
start %python_path% examples_reload_run.py
|
||||
goto :EOF
|
||||
|
||||
|
||||
@@ -22,8 +22,9 @@ app_name=$STANDALONE_APP_NAME
|
||||
main_class="com.tencent.supersonic.StandaloneLauncher"
|
||||
model_name=$service
|
||||
|
||||
if [ "$service" == "llmparser" ]; then
|
||||
if [ "$service" == "pyllm" ]; then
|
||||
model_name=${STANDALONE_SERVICE}
|
||||
export llmProxy=PythonLLMProxy
|
||||
fi
|
||||
|
||||
cd $baseDir
|
||||
@@ -43,14 +44,14 @@ function setAppName {
|
||||
app_name=$CHAT_APP_NAME
|
||||
elif [ "$service" == $SEMANTIC_SERVICE ]; then
|
||||
app_name=$SEMANTIC_APP_NAME
|
||||
elif [ "$service" == $LLMPARSER_SERVICE ]; then
|
||||
app_name=$LLMPARSER_APP_NAME
|
||||
elif [ "$service" == $PYLLM_SERVICE ]; then
|
||||
app_name=$PYLLM_APP_NAME
|
||||
fi
|
||||
}
|
||||
setAppName
|
||||
|
||||
function reloadExamples {
|
||||
pythonRunDir=${runtimeDir}/supersonic-${model_name}/llmparser
|
||||
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
|
||||
cd $pythonRunDir/sql
|
||||
${python_path} examples_reload_run.py
|
||||
}
|
||||
@@ -61,7 +62,7 @@ function start()
|
||||
local_app_name=$1
|
||||
pid=$(ps aux |grep ${local_app_name} | grep -v grep | awk '{print $2}')
|
||||
if [[ "$pid" == "" ]]; then
|
||||
if [[ ${local_app_name} == $LLMPARSER_APP_NAME ]]; then
|
||||
if [[ ${local_app_name} == $PYLLM_APP_NAME ]]; then
|
||||
runPythonService ${local_app_name}
|
||||
else
|
||||
runJavaService ${local_app_name}
|
||||
@@ -87,7 +88,7 @@ function stop()
|
||||
|
||||
function reload()
|
||||
{
|
||||
if [[ $1 == $LLMPARSER_APP_NAME ]]; then
|
||||
if [[ $1 == $PYLLM_APP_NAME ]]; then
|
||||
reloadExamples
|
||||
fi
|
||||
}
|
||||
@@ -95,11 +96,11 @@ function reload()
|
||||
# 4. execute command operation
|
||||
case "$command" in
|
||||
start)
|
||||
if [ "$service" == $STANDALONE_SERVICE ]; then
|
||||
echo "Starting $LLMPARSER_APP_NAME"
|
||||
start $LLMPARSER_APP_NAME
|
||||
if [ "$service" == $PYLLM_SERVICE ]; then
|
||||
echo "Starting $app_name"
|
||||
start $app_name
|
||||
echo "Starting $STANDALONE_APP_NAME"
|
||||
start $STANDALONE_APP_NAME
|
||||
else
|
||||
echo "Starting $app_name"
|
||||
start $app_name
|
||||
@@ -107,15 +108,10 @@ case "$command" in
|
||||
echo "Start success"
|
||||
;;
|
||||
stop)
|
||||
if [ "$service" == $STANDALONE_SERVICE ]; then
|
||||
echo "Stopping $LLMPARSER_APP_NAME"
|
||||
stop $LLMPARSER_APP_NAME
|
||||
echo "Stopping $app_name"
|
||||
stop $app_name
|
||||
else
|
||||
echo "Stopping $app_name"
|
||||
stop ${app_name}
|
||||
fi
|
||||
echo "Stopping $app_name"
|
||||
stop $app_name
|
||||
echo "Stopping $PYLLM_APP_NAME"
|
||||
stop $PYLLM_APP_NAME
|
||||
echo "Stop success"
|
||||
;;
|
||||
reload)
|
||||
@@ -124,15 +120,15 @@ case "$command" in
|
||||
echo "Reload success"
|
||||
;;
|
||||
restart)
|
||||
if [ "$service" == $STANDALONE_SERVICE ]; then
|
||||
if [ "$service" == $PYLLM_SERVICE ]; then
|
||||
echo "Stopping ${app_name}"
|
||||
stop ${app_name}
|
||||
echo "Stopping ${LLMPARSER_APP_NAME}"
|
||||
stop $LLMPARSER_APP_NAME
|
||||
echo "Starting ${LLMPARSER_APP_NAME}"
|
||||
start $LLMPARSER_APP_NAME
|
||||
echo "Stopping ${STANDALONE_APP_NAME}"
|
||||
stop $STANDALONE_APP_NAME
|
||||
echo "Starting ${app_name}"
|
||||
start ${app_name}
|
||||
echo "Starting ${STANDALONE_APP_NAME}"
|
||||
start $STANDALONE_APP_NAME
|
||||
else
|
||||
echo "Stopping ${app_name}"
|
||||
stop ${app_name}
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
</fileSet>
|
||||
<fileSet>
|
||||
<directory>${project.basedir}/../../chat/python</directory>
|
||||
<outputDirectory>llmparser</outputDirectory>
|
||||
<outputDirectory>pyllm</outputDirectory>
|
||||
<fileMode>0777</fileMode>
|
||||
<directoryMode>0755</directoryMode>
|
||||
</fileSet>
|
||||
|
||||
@@ -7,6 +7,9 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* UserAdaptor defines some interfaces for obtaining user and organization information
|
||||
*/
|
||||
public interface UserAdaptor {
|
||||
|
||||
List<String> getUserNames();
|
||||
|
||||
@@ -16,6 +16,9 @@ import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* DefaultUserAdaptor provides a default method to obtain user and organization information
|
||||
*/
|
||||
public class DefaultUserAdaptor implements UserAdaptor {
|
||||
|
||||
private List<UserDO> getUserDOList() {
|
||||
|
||||
@@ -29,7 +29,6 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
||||
|
||||
protected S2ThreadContext s2ThreadContext;
|
||||
|
||||
|
||||
protected boolean isExcludedUri(String uri) {
|
||||
String excludePathStr = authenticationConfig.getExcludePath();
|
||||
if (Strings.isEmpty(excludePathStr)) {
|
||||
@@ -59,7 +58,6 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
||||
return "true".equalsIgnoreCase(internal);
|
||||
}
|
||||
|
||||
|
||||
protected void reflectSetparam(HttpServletRequest request, String key, String value) {
|
||||
try {
|
||||
if (request instanceof StandardMultipartHttpServletRequest) {
|
||||
|
||||
@@ -76,5 +76,4 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor
|
||||
s2ThreadContext.set(threadContext);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -20,7 +20,6 @@ public class UserRepositoryImpl implements UserRepository {
|
||||
this.userDOMapper = userDOMapper;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<UserDO> getUserList() {
|
||||
return userDOMapper.selectByExample(new UserDOExample());
|
||||
@@ -40,5 +39,4 @@ public class UserRepositoryImpl implements UserRepository {
|
||||
return userDOOptional.orElse(null);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ public class UserController {
|
||||
this.userService = userService;
|
||||
}
|
||||
|
||||
|
||||
@GetMapping("/getCurrentUser")
|
||||
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||
return UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
@@ -70,5 +69,4 @@ public class UserController {
|
||||
return userService.login(userCmd);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import org.springframework.stereotype.Service;
|
||||
@Service
|
||||
public class UserServiceImpl implements UserService {
|
||||
|
||||
|
||||
@Override
|
||||
public List<String> getUserNames() {
|
||||
return ComponentFactory.getUserAdaptor().getUserNames();
|
||||
|
||||
@@ -20,5 +20,4 @@ public class FakeUserStrategy implements UserStrategy {
|
||||
return User.getFakeUser();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -58,7 +58,6 @@ public class UserTokenUtils {
|
||||
return generate(claims);
|
||||
}
|
||||
|
||||
|
||||
public User getUser(HttpServletRequest request) {
|
||||
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
||||
final Claims claims = getClaims(token);
|
||||
@@ -120,5 +119,4 @@ public class UserTokenUtils {
|
||||
.compact();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.auth.authorization.application;
|
||||
package com.tencent.supersonic.auth.authorization.service;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
@@ -75,7 +76,6 @@ public class AuthServiceImpl implements AuthService {
|
||||
jdbcTemplate.update("delete from s2_auth_groups where group_id = ?", group.getGroupId());
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||
@@ -109,8 +109,11 @@ public class AuthServiceImpl implements AuthService {
|
||||
}
|
||||
}
|
||||
|
||||
if (req.getModelId() != null) {
|
||||
List<AuthGroup> authGroups = authGroupsByModelId.get(req.getModelId());
|
||||
if (!CollectionUtils.isEmpty(req.getModelIds())) {
|
||||
List<AuthGroup> authGroups = Lists.newArrayList();
|
||||
for (Long modelId : authGroupsByModelId.keySet()) {
|
||||
authGroups.addAll(authGroupsByModelId.getOrDefault(modelId, Lists.newArrayList()));
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(authGroups)) {
|
||||
for (AuthGroup group : authGroups) {
|
||||
if (group.getDimensionFilters() != null
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import lombok.Data;
|
||||
@@ -42,7 +42,7 @@ public class SemanticParseInfo {
|
||||
private Map<String, Object> properties = new HashMap<>();
|
||||
private EntityInfo entityInfo;
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private QueryType queryType = QueryType.OTHER;
|
||||
private QueryType queryType = QueryType.ID;
|
||||
|
||||
public String getModelClusterKey() {
|
||||
if (model == null) {
|
||||
@@ -88,10 +88,11 @@ public class SemanticParseInfo {
|
||||
|
||||
private Map<Long, Integer> getModelElementCountMap() {
|
||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||
elementMatches.forEach(element -> {
|
||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||
});
|
||||
elementMatches.stream().filter(element -> element.getElement().getModel() != null)
|
||||
.forEach(element -> {
|
||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||
});
|
||||
return elementCountMap;
|
||||
}
|
||||
|
||||
|
||||
@@ -57,19 +57,19 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsByName(name, getEntities());
|
||||
element = getElementsByNameOrAlias(name, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsByName(name, getModels());
|
||||
element = getElementsByNameOrAlias(name, getModels());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsByName(name, getMetrics());
|
||||
element = getElementsByNameOrAlias(name, getMetrics());
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = getElementsByName(name, getDimensions());
|
||||
element = getElementsByNameOrAlias(name, getDimensions());
|
||||
break;
|
||||
case VALUE:
|
||||
element = getElementsByName(name, getDimensionValues());
|
||||
element = getElementsByNameOrAlias(name, getDimensionValues());
|
||||
break;
|
||||
default:
|
||||
}
|
||||
@@ -151,10 +151,11 @@ public class SemanticSchema implements Serializable {
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
|
||||
private Optional<SchemaElement> getElementsByNameOrAlias(String name, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> name.equals(schemaElement.getName()))
|
||||
.findFirst();
|
||||
.filter(schemaElement ->
|
||||
name.equals(schemaElement.getName()) || schemaElement.getAlias().contains(name)
|
||||
).findFirst();
|
||||
}
|
||||
|
||||
public List<SchemaElement> getModels() {
|
||||
|
||||
@@ -10,7 +10,7 @@ import lombok.NoArgsConstructor;
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class SolvedQueryReq {
|
||||
public class SimilarQueryReq {
|
||||
|
||||
private Long queryId;
|
||||
|
||||
@@ -12,7 +12,6 @@ public class ParseResp {
|
||||
private Long queryId;
|
||||
private ParseState state;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
|
||||
|
||||
public enum ParseState {
|
||||
|
||||
@@ -6,6 +6,6 @@ import java.util.List;
|
||||
|
||||
@Data
|
||||
public class QueryRecallResp {
|
||||
private List<SolvedQueryRecallResp> solvedQueryRecallRespList;
|
||||
private List<SimilarQueryRecallResp> solvedQueryRecallRespList;
|
||||
private Long queryTimeCost;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class QueryResp {
|
||||
@@ -16,4 +16,5 @@ public class QueryResp {
|
||||
private String queryText;
|
||||
private QueryResult queryResult;
|
||||
private List<SemanticParseInfo> parseInfos;
|
||||
private List<SimilarQueryRecallResp> similarQueries;
|
||||
}
|
||||
@@ -1,11 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import lombok.Data;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class QueryResult {
|
||||
@@ -22,4 +23,5 @@ public class QueryResult {
|
||||
private Object response;
|
||||
private List<Map<String, Object>> queryResults;
|
||||
private Long queryTimeCost;
|
||||
private List<SchemaElement> recommendedDimensions;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class SolvedQueryRecallResp {
|
||||
public class SimilarQueryRecallResp {
|
||||
|
||||
private Long queryId;
|
||||
|
||||
@@ -40,11 +40,6 @@
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
@@ -89,7 +84,6 @@
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>semantic-query</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
@@ -97,12 +91,6 @@
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>semantic-query</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.xkzhangsan</groupId>
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import java.util.Objects;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentTool;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
public enum AgentToolType {
|
||||
NL2SQL_RULE,
|
||||
NL2SQL_LLM,
|
||||
PLUGIN,
|
||||
ANALYTICS
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class DataAnalyticsTool extends AgentTool {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LLMParserTool extends NL2SQLTool {
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
@@ -9,7 +9,7 @@ import lombok.NoArgsConstructor;
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class CommonAgentTool extends AgentTool {
|
||||
public class NL2SQLTool extends AgentTool {
|
||||
|
||||
protected List<Long> modelIds;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -7,7 +7,7 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RuleQueryTool extends CommonAgentTool {
|
||||
public class RuleParserTool extends NL2SQLTool {
|
||||
|
||||
|
||||
private List<String> queryModes;
|
||||
@@ -1,8 +0,0 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
public enum AgentToolType {
|
||||
RULE,
|
||||
LLM_S2SQL,
|
||||
PLUGIN,
|
||||
INTERPRET
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LLMParserTool extends CommonAgentTool {
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
public class MetricInterpretTool extends AgentTool {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private List<MetricOption> metricOptions;
|
||||
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -57,16 +58,19 @@ public class OptimizationConfig {
|
||||
@Value("${s2SQL.linking.value.switch:true}")
|
||||
private boolean useLinkingValueSwitch;
|
||||
|
||||
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
|
||||
private SqlGenerationMode sqlGenerationMode;
|
||||
|
||||
@Value("${s2SQL.use.switch:true}")
|
||||
private boolean useS2SqlSwitch;
|
||||
|
||||
@Value("${text2sql.example.num:10}")
|
||||
private int text2sqlExampleNum;
|
||||
|
||||
@Value("${text2sql.fewShots.num:10}")
|
||||
@Value("${text2sql.fewShots.num:5}")
|
||||
private int text2sqlFewShotsNum;
|
||||
|
||||
@Value("${text2sql.self.consistency.num:5}")
|
||||
@Value("${text2sql.self.consistency.num:2}")
|
||||
private int text2sqlSelfConsistencyNum;
|
||||
|
||||
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
||||
@@ -139,6 +143,10 @@ public class OptimizationConfig {
|
||||
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||
}
|
||||
|
||||
public SqlGenerationMode getSqlGenerationMode() {
|
||||
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
||||
}
|
||||
|
||||
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||
try {
|
||||
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||
@@ -151,6 +159,8 @@ public class OptimizationConfig {
|
||||
return targetType.cast(Integer.parseInt(value));
|
||||
} else if (targetType == Boolean.class) {
|
||||
return targetType.cast(Boolean.parseBoolean(value));
|
||||
} else if (targetType == SqlGenerationMode.class) {
|
||||
return targetType.cast(SqlGenerationMode.valueOf(value));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("convertValue", e);
|
||||
|
||||
@@ -14,6 +14,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -43,7 +44,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
|
||||
@@ -114,7 +114,15 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||
}
|
||||
return schemaElement;
|
||||
}).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));
|
||||
}).flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream().map(element -> Pair.of(element, schemaElement.getDefaultAgg())
|
||||
);
|
||||
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
|
||||
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
|
||||
@@ -4,7 +4,9 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "From" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class FromCorrector extends BaseSemanticCorrector {
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "group by" section in S2SQL.
|
||||
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -62,5 +62,4 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.corrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
@@ -86,7 +86,6 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
return parseResult.getLinkingValues();
|
||||
}
|
||||
|
||||
|
||||
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.S2SQLDateHelper;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -76,7 +76,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
|
||||
@@ -19,9 +19,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
/**
|
||||
* base Mapper
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
@@ -44,7 +41,6 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||
|
||||
@@ -19,9 +19,6 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* Base Match Strategy
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
@@ -154,5 +151,4 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
@@ -22,11 +22,12 @@ import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Fuzzy Name Match Strategy
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||
* It currently supports fuzzy matching against names and aliases.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@@ -36,27 +37,26 @@ public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||
private SchemaService schemaService;
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<FuzzyResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
this.allElements = getSchemaElements();
|
||||
return super.match(queryContext, terms, detectModelIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) {
|
||||
public boolean needDelete(DatabaseMapResult oneRoundResult, DatabaseMapResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(FuzzyResult a) {
|
||||
public String getMapKey(DatabaseMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<FuzzyResult> existResults, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
@@ -81,11 +81,11 @@ public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
FuzzyResult fuzzyResult = new FuzzyResult();
|
||||
fuzzyResult.setDetectWord(detectSegment);
|
||||
fuzzyResult.setName(schemaElement.getName());
|
||||
fuzzyResult.setSchemaElement(schemaElement);
|
||||
existResults.add(fuzzyResult);
|
||||
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
|
||||
databaseMapResult.setDetectWord(detectSegment);
|
||||
databaseMapResult.setName(schemaElement.getName());
|
||||
databaseMapResult.setSchemaElement(schemaElement);
|
||||
existResults.add(databaseMapResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -97,7 +97,6 @@ public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||
return allElements;
|
||||
}
|
||||
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
@@ -15,7 +15,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/***
|
||||
* A mapper that is capable of semantic understanding of text.
|
||||
* A mapper that recognizes schema elements with vector embedding.
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingMapper extends BaseMapper {
|
||||
@@ -23,7 +23,6 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
//1. query from embedding by queryText
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
@@ -39,11 +38,11 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||
SchemaElement.class);
|
||||
|
||||
if (StringUtils.isBlank(matchResult.getMetadata().get("modelId"))) {
|
||||
String modelIdStr = matchResult.getMetadata().get("modelId");
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
continue;
|
||||
}
|
||||
long modelId = Long.parseLong(matchResult.getMetadata().get("modelId"));
|
||||
|
||||
long modelId = Long.parseLong(modelIdStr);
|
||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
|
||||
@@ -4,10 +4,11 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||
import java.util.Comparator;
|
||||
@@ -24,7 +25,8 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* match strategy implement
|
||||
* EmbeddingMatchStrategy uses vector database to perform
|
||||
* similarity search against the embeddings of schema elements.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -32,8 +34,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private EmbeddingUtils embeddingUtils;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
@Override
|
||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||
@@ -46,7 +48,6 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
Set<String> detectSegments) {
|
||||
@@ -84,7 +85,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
.queryEmbeddings(null)
|
||||
.build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
||||
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
@@ -98,7 +99,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
String modelIdStr = retrieval.getMetadata().get("modelId");
|
||||
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/***
|
||||
* A mapper capable of fuzzy parsing of metric names and dimension names.
|
||||
*/
|
||||
@Slf4j
|
||||
public class FuzzyNameMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
|
||||
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
||||
|
||||
FuzzyNameMatchStrategy fuzzyNameMatchStrategy = ContextUtils.getBean(FuzzyNameMatchStrategy.class);
|
||||
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
|
||||
List<FuzzyResult> matches = fuzzyNameMatchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
for (FuzzyResult match : matches) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(10000L)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return elements.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
/***
|
||||
* A mapper capable of prefix and suffix similarity parsing for
|
||||
* domain names, dimension values, metric names, and dimension names.
|
||||
*/
|
||||
@Slf4j
|
||||
public class HanlpDictMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
HanlpDictMatchStrategy matchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matches);
|
||||
|
||||
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
|
||||
}
|
||||
|
||||
|
||||
private void convertTermsToSchemaMapInfo(List<HanlpMapResult> hanlpMapResults, SchemaMapInfo schemaMap,
|
||||
List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Long> wordNatureToFrequency = terms.stream().collect(
|
||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
if (Objects.isNull(modelId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||
if (Objects.isNull(elementType)) {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||
element.setName(hanlpMapResult.getName());
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.word(hanlpMapResult.getName())
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -21,7 +21,9 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* match strategy implement
|
||||
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
|
||||
* match schema elements. It currently supports prefix and suffix matching
|
||||
* against names, values and aliases.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with keyword.
|
||||
* It leverages two matching strategies: HanlpDictMatchStrategy and DatabaseMatchStrategy.
|
||||
*/
|
||||
@Slf4j
|
||||
public class KeywordMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
//1.hanlpDict Match
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext.getMapInfo(), terms);
|
||||
|
||||
//2.database Match
|
||||
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||
|
||||
List<DatabaseMapResult> databaseResults = databaseMatchStrategy.getMatches(queryContext, terms);
|
||||
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, SchemaMapInfo schemaMap,
|
||||
List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
HanlpHelper.transLetterOriginal(mapResults);
|
||||
Map<String, Long> wordNatureToFrequency = terms.stream().collect(
|
||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
if (Objects.isNull(modelId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||
if (Objects.isNull(elementType)) {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||
element.setName(hanlpMapResult.getName());
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.word(hanlpMapResult.getName())
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void convertDatabaseMapResultToMapInfo(QueryContext queryContext, List<DatabaseMapResult> mapResults) {
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
for (DatabaseMapResult match : mapResults) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(10000L)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return elements.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
@@ -19,10 +19,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* Mapper helper
|
||||
*/
|
||||
|
||||
@Data
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -41,7 +37,6 @@ public class MapperHelper {
|
||||
return index;
|
||||
}
|
||||
|
||||
|
||||
public Integer getStepOffset(List<Term> termList, Integer index) {
|
||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
|
||||
@@ -7,7 +7,8 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* match strategy
|
||||
* MatchStrategy encapsulates a concrete matching algorithm
|
||||
* executed during query or search process.
|
||||
*/
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
|
||||
@@ -18,6 +18,11 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* ModelClusterMapper build a cluster from
|
||||
* connectable data models based on model-rela configuration
|
||||
* and generate SchemaModelClusterMapInfo
|
||||
*/
|
||||
public class ModelClusterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
|
||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
@@ -19,8 +20,6 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterMapper implements SchemaMapper {
|
||||
|
||||
private Long frequency = 9999999L;
|
||||
private double similarity = 1.0;
|
||||
|
||||
@Override
|
||||
@@ -37,7 +36,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
|
||||
}
|
||||
addValueSchemaElementMatch(schemaElementMatches, queryReq.getQueryFilters());
|
||||
addValueSchemaElementMatch(queryContext, schemaElementMatches, queryReq.getQueryFilters());
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||
@@ -48,7 +47,8 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
}
|
||||
}
|
||||
|
||||
private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches,
|
||||
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
|
||||
List<SchemaElementMatch> candidateElementMatches,
|
||||
QueryFilters queryFilter) {
|
||||
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
|
||||
return candidateElementMatches;
|
||||
@@ -62,10 +62,11 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.model(queryContext.getRequest().getModelId())
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(String.valueOf(filter.getValue()))
|
||||
.similarity(similarity)
|
||||
.detectWord(Constants.EMPTY)
|
||||
|
||||
@@ -18,7 +18,8 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* match strategy implement
|
||||
* SearchMatchStrategy encapsulates a concrete matching algorithm
|
||||
* executed during search process.
|
||||
*/
|
||||
@Service
|
||||
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
if (Objects.isNull(chatLanguageModel)) {
|
||||
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
Map<String, Double> sqlWeight = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
result.setSqlWeight(sqlWeight);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
|
||||
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
||||
|
||||
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||
functionReq.getPluginConfigs());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
|
||||
String functionSelect = chatLanguageModel.generate(functionCallPrompt);
|
||||
|
||||
return OutputFormat.functionCallParse(functionSelect);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,15 +1,19 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* Unified interpreter for invoking the llm layer.
|
||||
* LLMProxy encapsulates functions performed by LLMs so that multiple
|
||||
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
|
||||
* could be used.
|
||||
*/
|
||||
public interface LLMInterpreter {
|
||||
public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
@@ -12,16 +13,32 @@ import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
/**
|
||||
* PythonLLMProxy sends requests to LangChain-based python service.
|
||||
*/
|
||||
@Slf4j
|
||||
public class HttpLLMInterpreter implements LLMInterpreter {
|
||||
@Component
|
||||
public class PythonLLMProxy implements LLMProxy {
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
@@ -9,10 +9,10 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
@@ -26,8 +26,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Query type parser, determine if the query is a metric query, an entity query,
|
||||
* or another type of query.
|
||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
@@ -51,11 +50,11 @@ public class QueryTypeParser implements SemanticParser {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
return QueryType.OTHER;
|
||||
return QueryType.ID;
|
||||
}
|
||||
//1. entity queryType
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
@@ -79,7 +78,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
}
|
||||
return QueryType.OTHER;
|
||||
return QueryType.ID;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -18,9 +18,9 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class SatisfactionChecker {
|
||||
|
||||
// check all the parse info in candidate
|
||||
public static boolean check(QueryContext queryContext) {
|
||||
public static boolean isSkip(QueryContext queryContext) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
if (query.getQueryMode().equals(S2SQLQuery.QUERY_MODE)) {
|
||||
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.interpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class MetricInterpretParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (SatisfactionChecker.check(queryContext)) {
|
||||
log.info("skip MetricInterpretParser");
|
||||
return;
|
||||
}
|
||||
Map<Long, MetricInterpretTool> metricInterpretToolMap =
|
||||
getMetricInterpretTools(queryContext.getRequest().getAgentId());
|
||||
log.info("metric interpret tool : {}", metricInterpretToolMap);
|
||||
if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
|
||||
return;
|
||||
}
|
||||
Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
for (Long modelId : elementMatches.keySet()) {
|
||||
MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId);
|
||||
if (metricInterpretTool == null) {
|
||||
continue;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(elementMatches.get(modelId))) {
|
||||
continue;
|
||||
}
|
||||
List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions();
|
||||
if (!CollectionUtils.isEmpty(metricOptions)) {
|
||||
List<Long> metricIds = metricOptions.stream()
|
||||
.map(MetricOption::getMetricId).collect(Collectors.toList());
|
||||
String name = metricInterpretTool.getName();
|
||||
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void buildQuery(Long modelId, QueryContext queryContext,
|
||||
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
LLMSemanticQuery metricInterpretQuery = QueryManager.createLLMQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
||||
metrics, schemaElementMatches, toolName);
|
||||
semanticParseInfo.setQueryMode(metricInterpretQuery.getQueryMode());
|
||||
semanticParseInfo.getProperties().put("queryText", queryContext.getRequest().getQueryText());
|
||||
metricInterpretQuery.setParseInfo(semanticParseInfo);
|
||||
queryContext.getCandidateQueries().add(metricInterpretQuery);
|
||||
}
|
||||
|
||||
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
List<SchemaElement> metrics = semanticService.getSemanticSchema().getMetrics();
|
||||
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private Map<Long, MetricInterpretTool> getMetricInterpretTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.INTERPRET);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<MetricInterpretTool> metricInterpretTools = tools.stream().map(tool ->
|
||||
JSONObject.parseObject(tool, MetricInterpretTool.class))
|
||||
.filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions()))
|
||||
.collect(Collectors.toList());
|
||||
Map<Long, MetricInterpretTool> metricInterpretToolMap = new HashMap<>();
|
||||
for (MetricInterpretTool metricInterpretTool : metricInterpretTools) {
|
||||
metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(),
|
||||
metricInterpretTool);
|
||||
}
|
||||
return metricInterpretToolMap;
|
||||
}
|
||||
|
||||
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
|
||||
List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setMetrics(metrics);
|
||||
SchemaElement dimension = new SchemaElement();
|
||||
dimension.setBizName(TimeDimensionEnum.DAY.getName());
|
||||
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||
semanticParseInfo.setScore(queryReq.getQueryText().length());
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||
dateConf.setUnit(15);
|
||||
semanticParseInfo.setDateInfo(dateConf);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", toolName);
|
||||
semanticParseInfo.setProperties(properties);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.forEach(schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.interpret;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class MetricOption {
|
||||
|
||||
private Long metricId;
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
package com.tencent.supersonic.chat.parser.plugin;
|
||||
|
||||
public enum ParseMode {
|
||||
|
||||
RULE,
|
||||
EMBEDDING_RECALL,
|
||||
FUNCTION_CALL;
|
||||
|
||||
@@ -27,6 +27,10 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
/**
|
||||
* PluginParser defines the basic process and common methods for recalling plugins.
|
||||
*/
|
||||
public abstract class PluginParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
@@ -99,7 +103,6 @@ public abstract class PluginParser implements SemanticParser {
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
|
||||
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
|
||||
@@ -2,9 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
@@ -12,25 +11,29 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class EmbeddingBasedParser extends PluginParser {
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
/**
|
||||
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingRecallParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl()) && llmInterpreter instanceof HttpLLMInterpreter) {
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||
return false;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
@@ -40,13 +43,13 @@ public class EmbeddingBasedParser extends PluginParser {
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
@@ -59,7 +62,7 @@ public class EmbeddingBasedParser extends PluginParser {
|
||||
continue;
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = Double.parseDouble(embeddingRetrieval.getDistance());
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||
@@ -68,14 +71,15 @@ public class EmbeddingBasedParser extends PluginParser {
|
||||
return null;
|
||||
}
|
||||
|
||||
public List<RecallRetrieval> embeddingRecall(String embeddingText) {
|
||||
public List<Retrieval> embeddingRecall(String embeddingText) {
|
||||
try {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText);
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
|
||||
|
||||
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
|
||||
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList());
|
||||
Math.abs(o.getDistance()))).collect(Collectors.toList());
|
||||
embeddingResp.setRetrieval(embeddingRetrievals);
|
||||
}
|
||||
return embeddingRetrievals;
|
||||
@@ -6,7 +6,7 @@ import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class EmbeddingResp {
|
||||
public class RecallRetrievalResp {
|
||||
|
||||
private String query;
|
||||
|
||||
@@ -1,39 +1,40 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionBasedParser extends PluginParser {
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
/**
|
||||
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
||||
*/
|
||||
@Slf4j
|
||||
public class FunctionCallParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||
String functionUrl = functionCallConfig.getUrl();
|
||||
if (StringUtils.isBlank(functionUrl) && llmInterpreter instanceof HttpLLMInterpreter) {
|
||||
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||
queryContext.getRequest().getQueryText());
|
||||
return false;
|
||||
@@ -84,7 +85,7 @@ public class FunctionBasedParser extends PluginParser {
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryContext.getRequest().getQueryText())
|
||||
.pluginConfigs(pluginToFunctionCall).build();
|
||||
functionResp = llmInterpreter.requestFunction(functionReq);
|
||||
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
|
||||
}
|
||||
return functionResp;
|
||||
}
|
||||
@@ -97,7 +98,7 @@ public class FunctionBasedParser extends PluginParser {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
if (S2SQLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
return false;
|
||||
}
|
||||
if (plugin.getParseModeConfig() == null) {
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.sql.llm.InputFormat;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class FunctionPromptGenerator {
|
||||
|
||||
public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
|
||||
List<String> toolExplainList = toolConfigList.stream()
|
||||
.map(this::constructPluginPrompt)
|
||||
.collect(Collectors.toList());
|
||||
String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
|
||||
return constructTaskPrompt(queryText, functionList);
|
||||
}
|
||||
|
||||
public String constructPluginPrompt(PluginParseConfig parseConfig) {
|
||||
String toolName = parseConfig.getName();
|
||||
String toolDescription = parseConfig.getDescription();
|
||||
List<String> toolExamples = parseConfig.getExamples();
|
||||
|
||||
StringBuilder prompt = new StringBuilder();
|
||||
prompt.append("【工具名称】\n").append(toolName).append("\n");
|
||||
prompt.append("【工具描述】\n").append(toolDescription).append("\n");
|
||||
prompt.append("【工具适用问题示例】\n");
|
||||
for (String example : toolExamples) {
|
||||
prompt.append(example).append("\n");
|
||||
}
|
||||
return prompt.toString();
|
||||
}
|
||||
|
||||
public String constructTaskPrompt(String queryText, String functionList) {
|
||||
String instruction = String.format("问题为:%s\n请根据问题和工具的描述,选择对应的工具,完成任务。"
|
||||
+ "请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),"
|
||||
+ "并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘", queryText);
|
||||
|
||||
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* RuleBasedParser acts as a container that incorporates a group of
|
||||
* rule-based semantic parsers.
|
||||
*/
|
||||
@Slf4j
|
||||
public class RuleBasedParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> ruleParsers = Arrays.asList(
|
||||
new QueryModeParser(),
|
||||
new ContextInheritParser(),
|
||||
new AgentCheckParser(),
|
||||
new TimeRangeParser(),
|
||||
new AggregateTypeParser()
|
||||
);
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
ruleParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -86,7 +86,6 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, ModelMatchResult> modelCount = new HashMap<>();
|
||||
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
||||
@@ -114,7 +113,6 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
return modelCount;
|
||||
}
|
||||
|
||||
|
||||
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
||||
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class InputFormat {
|
||||
|
||||
public static final String SEPERATOR = "\n\n";
|
||||
|
||||
public static String format(String template, List<String> templateKey,
|
||||
List<Map<String, String>> toFormatList) {
|
||||
List<String> result = new ArrayList<>();
|
||||
|
||||
for (Map<String, String> formatItem : toFormatList) {
|
||||
Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
|
||||
result.add(format(template, retrievalMeta));
|
||||
}
|
||||
|
||||
return String.join(SEPERATOR, result);
|
||||
}
|
||||
|
||||
public static String format(String input, Map<String, String> replacements) {
|
||||
for (Map.Entry<String, String> entry : replacements.entrySet()) {
|
||||
input = input.replace(entry.getKey(), entry.getValue());
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
|
||||
Map<String, String> subDict = new HashMap<>();
|
||||
for (String key : keys) {
|
||||
if (dict.containsKey(key)) {
|
||||
subDict.put(key, dict.get(key));
|
||||
}
|
||||
}
|
||||
return subDict;
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
@@ -26,13 +25,6 @@ import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
@@ -42,13 +34,17 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
|
||||
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
@@ -59,22 +55,19 @@ public class LLMRequestService {
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
|
||||
public boolean check(QueryContext queryCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2SQLParser.class, llmParserConfig);
|
||||
public boolean isSkip(QueryContext queryCtx) {
|
||||
if (ComponentFactory.getLLMProxy().isSkip(queryCtx)) {
|
||||
return true;
|
||||
}
|
||||
if (SatisfactionChecker.check(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMS2SQLParser.class, request.getQueryText());
|
||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
||||
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
|
||||
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.NL2SQL_LLM);
|
||||
if (agentService.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
@@ -84,10 +77,10 @@ public class LLMRequestService {
|
||||
return ModelCluster.build(modelCluster);
|
||||
}
|
||||
|
||||
public CommonAgentTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
|
||||
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
|
||||
AgentToolType.LLM_S2SQL);
|
||||
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
|
||||
public NL2SQLTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
|
||||
List<NL2SQLTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
|
||||
AgentToolType.NL2SQL_LLM);
|
||||
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
|
||||
@@ -105,7 +98,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
String queryText = queryCtx.getRequest().getQueryText();
|
||||
|
||||
@@ -134,20 +127,21 @@ public class LLMRequestService {
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(firstModelId);
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(firstModelId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
|
||||
return llmInterpreter.query2sql(llmReq, modelClusterKey);
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, modelClusterKey);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
|
||||
|
||||
@@ -192,7 +186,6 @@ public class LLMRequestService {
|
||||
return extraInfoSb.toString();
|
||||
}
|
||||
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||
|
||||
@@ -223,7 +216,6 @@ public class LLMRequestService {
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
private Set<String> getTopNFieldNames(ModelCluster modelCluster, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
@@ -242,7 +234,6 @@ public class LLMRequestService {
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -25,10 +25,10 @@ public class LLMResponseService {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(parseInfo.getModelClusterKey()));
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
@@ -21,7 +21,7 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class LLMS2SQLParser implements SemanticParser {
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
@@ -29,7 +29,7 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
if (requestService.check(queryCtx)) {
|
||||
if (requestService.isSkip(queryCtx)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
@@ -39,9 +39,9 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
return;
|
||||
}
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
|
||||
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
@@ -79,5 +79,4 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -0,0 +1,75 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
||||
List<String> linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||
List<String> llmResults = new CopyOnWriteArrayList<>();
|
||||
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||
.apply(new HashMap<>());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
llmResults.add(response.content().text());
|
||||
}
|
||||
);
|
||||
//3.format response.
|
||||
List<String> schemaLinkingResults = llmResults.stream()
|
||||
.map(llmResult -> OutputFormat.getSchemaLinks(llmResult)).collect(Collectors.toList());
|
||||
List<String> candidateSortedList = OutputFormat.formatList(schemaLinkingResults);
|
||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(candidateSortedList);
|
||||
List<String> sqlList = llmResults.stream()
|
||||
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
|
||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlList);
|
||||
log.info("linkingMap result:{},sqlMap:{}", linkingMap, sqlMap);
|
||||
return sqlMap.getRight();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
|
||||
//3.format response.
|
||||
String llmResult = response.content().text();
|
||||
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
|
||||
String sql = OutputFormat.getSql(response.content().text());
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(sql, 1D);
|
||||
log.info("llmResult:{},schemaLinkStr:{},sql:{}", llmResult, schemaLinkStr, sql);
|
||||
return sqlMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
/***
|
||||
* output format
|
||||
*/
|
||||
@Slf4j
|
||||
public class OutputFormat {
|
||||
|
||||
public static final String PATTERN = "\\{[^{}]+\\}";
|
||||
|
||||
public static String getSchemaLink(String schemaLink) {
|
||||
String reult = "";
|
||||
try {
|
||||
reult = schemaLink.trim();
|
||||
String pattern = "Schema_links:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
|
||||
Matcher matcher = regexPattern.matcher(reult);
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1).trim();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return reult;
|
||||
}
|
||||
|
||||
public static String getSql(String sqlOutput) {
|
||||
String sql = "";
|
||||
try {
|
||||
sqlOutput = sqlOutput.trim();
|
||||
String pattern = "SQL:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(sqlOutput);
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return sql;
|
||||
}
|
||||
|
||||
public static String getSchemaLinks(String text) {
|
||||
String schemaLinks = "";
|
||||
try {
|
||||
text = text.trim();
|
||||
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(text);
|
||||
|
||||
if (matcher.find()) {
|
||||
if (matcher.group(1) != null) {
|
||||
schemaLinks = matcher.group(1);
|
||||
} else if (matcher.group(2) != null) {
|
||||
schemaLinks = matcher.group(2);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
|
||||
return schemaLinks;
|
||||
}
|
||||
|
||||
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
|
||||
Map<String, Integer> inputCounts = new HashMap<>();
|
||||
for (String input : inputList) {
|
||||
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
|
||||
}
|
||||
|
||||
String inputMax = null;
|
||||
int maxCount = 0;
|
||||
int inputSize = inputList.size();
|
||||
Map<String, Double> votePercentage = new HashMap<>();
|
||||
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
|
||||
String input = entry.getKey();
|
||||
int count = entry.getValue();
|
||||
if (count > maxCount) {
|
||||
inputMax = input;
|
||||
maxCount = count;
|
||||
}
|
||||
double percentage = (double) count / inputSize;
|
||||
votePercentage.put(input, percentage);
|
||||
}
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static List<String> formatList(List<String> toFormatList) {
|
||||
List<String> results = new ArrayList<>();
|
||||
for (String toFormat : toFormatList) {
|
||||
List<String> items = new ArrayList<>();
|
||||
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
|
||||
for (String item : split) {
|
||||
items.add(item.trim());
|
||||
}
|
||||
Collections.sort(items);
|
||||
String result = "[" + String.join(",", items) + "]";
|
||||
results.add(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public static FunctionResp functionCallParse(String llmOutput) {
|
||||
try {
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
JsonNode jsonNode = objectMapper.readTree(llmOutput);
|
||||
String selectedTool = jsonNode.get("选择工具").asText();
|
||||
FunctionResp resp = new FunctionResp();
|
||||
resp.setToolSelection(selectedTool);
|
||||
return resp;
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
@@ -27,7 +27,7 @@ public class ParseResult {
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private CommonAgentTool commonAgentTool;
|
||||
private NL2SQLTool commonAgentTool;
|
||||
|
||||
private List<ElementValue> linkingValues;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
@@ -11,7 +11,7 @@ import java.util.List;
|
||||
import java.util.Objects;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
public class S2SQLDateHelper {
|
||||
public class S2SqlDateHelper {
|
||||
|
||||
public static String getReferenceDate(Long modelId) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SqlExample {
|
||||
|
||||
private String question;
|
||||
|
||||
private String questionAugmented;
|
||||
|
||||
private String dbSchema;
|
||||
|
||||
private String sql;
|
||||
|
||||
private String generatedSchemaLinkingCoT;
|
||||
|
||||
private String generatedSchemaLinkings;
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SqlExampleLoader {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||
};
|
||||
|
||||
public List<SqlExample> getSqlExamples() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
}
|
||||
|
||||
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
|
||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||
for (int i = 0; i < sqlExamples.size(); i++) {
|
||||
SqlExample sqlExample = sqlExamples.get(i);
|
||||
String question = sqlExample.getQuestion();
|
||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class);
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(String.valueOf(i));
|
||||
embeddingQuery.setQuery(question);
|
||||
embeddingQuery.setMetadata(metaDataMap);
|
||||
queries.add(embeddingQuery);
|
||||
}
|
||||
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||
}
|
||||
|
||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||
.queryEmbeddings(null).build();
|
||||
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery,
|
||||
maxResults);
|
||||
List<Map<String, String>> result = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return result;
|
||||
}
|
||||
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
|
||||
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
||||
result.add(convertedMap);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Sql Generation interface, generating SQL using a large model.
|
||||
*/
|
||||
public interface SqlGeneration {
|
||||
|
||||
/***
|
||||
* generate SQL through LLMReq.
|
||||
* @param llmReq
|
||||
* @param modelClusterKey
|
||||
* @return
|
||||
*/
|
||||
Map<String, Double> generation(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class SqlGenerationFactory {
|
||||
|
||||
private static Map<SqlGenerationMode, SqlGeneration> sqlGenerationMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static SqlGeneration get(SqlGenerationMode strategyType) {
|
||||
return sqlGenerationMap.get(strategyType);
|
||||
}
|
||||
|
||||
public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) {
|
||||
sqlGenerationMap.put(strategy, sqlGeneration);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class SqlPromptGenerator {
|
||||
|
||||
public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction =
|
||||
"# Find the schema_links for generating SQL queries for each question based on the database schema "
|
||||
+ "and Foreign keys. Then use the the schema links to generate the "
|
||||
+ "SQL queries for each of the questions.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT\nSQL: sql";
|
||||
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction = "# Find the schema_links for generating SQL queries for each question "
|
||||
+ "based on the database schema and Foreign keys.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
|
||||
+ "SQL: sql";
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
public List<String> generatePromptPool(LLMReq llmReq, List<List<Map<String, String>>> exampleListPool,
|
||||
boolean isSqlPrompt) {
|
||||
List<String> promptPool = new ArrayList<>();
|
||||
for (List<Map<String, String>> exampleList : exampleListPool) {
|
||||
String prompt;
|
||||
if (isSqlPrompt) {
|
||||
prompt = generatorLinkingAndSqlPrompt(llmReq, exampleList);
|
||||
} else {
|
||||
prompt = generateLinkingPrompt(llmReq, exampleList);
|
||||
}
|
||||
promptPool.add(prompt);
|
||||
}
|
||||
return promptPool;
|
||||
}
|
||||
|
||||
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
|
||||
int numSelfConsistency) {
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
for (int i = 0; i < numSelfConsistency; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, numFewShots));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
String dbSchema = "Table: " + modelName + ", Columns = " + fieldNameList + "\nForeign_keys: []";
|
||||
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (ElementValue priorLinking : linking) {
|
||||
String fieldName = priorLinking.getFieldName();
|
||||
String fieldValue = priorLinking.getFieldValue();
|
||||
priorLinkingList.add("'" + fieldValue + "'是一个'" + fieldName + "'");
|
||||
}
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String questionAugmented = String.format("%s (补充信息:%s 。 %s) (备注: %s)", llmReq.getQueryText(), linkingListStr,
|
||||
currentDataStr, priorExts);
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
|
||||
String schemaLinkStr = schemaLinkStrPool.get(i);
|
||||
List<Map<String, String>> fewshotExampleList = fewshotExampleListPool.get(i);
|
||||
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
|
||||
sqlPromptPool.add(sqlPrompt);
|
||||
}
|
||||
return sqlPromptPool;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
//2.generator linking prompt,and parallel generate response.
|
||||
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
linkingPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = linkingResult.content().text();
|
||||
linkingResults.add(OutputFormat.getSchemaLink(result));
|
||||
}
|
||||
);
|
||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||
Pair<String, Map<String, Double>> linkingMap = OutputFormat.selfConsistencyVote(sortedList);
|
||||
//3.generator sql prompt,and parallel generate response.
|
||||
List<String> sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool);
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
String result = sqlResult.content().text();
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMap = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
log.info("linkingMap result:{},sqlMap:{}", linkingMap, sqlMap);
|
||||
return sqlMap.getRight();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public Map<String, Double> generation(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
|
||||
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
||||
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(sqlResult.content().text(), 1D);
|
||||
return sqlMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
package com.tencent.supersonic.chat.parser.sql.rule;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
|
||||
import com.tencent.supersonic.chat.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -35,7 +35,7 @@ public class AgentCheckParser implements SemanticParser {
|
||||
if (agent == null) {
|
||||
return;
|
||||
}
|
||||
List<RuleQueryTool> queryTools = getRuleTools(agentId);
|
||||
List<RuleParserTool> queryTools = getRuleTools(agentId);
|
||||
if (CollectionUtils.isEmpty(queryTools)) {
|
||||
queries.clear();
|
||||
return;
|
||||
@@ -43,7 +43,7 @@ public class AgentCheckParser implements SemanticParser {
|
||||
log.info("queries resolved:{} {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
queries.removeIf(query -> {
|
||||
for (RuleQueryTool tool : queryTools) {
|
||||
for (RuleParserTool tool : queryTools) {
|
||||
if (CollectionUtils.isNotEmpty(tool.getQueryModes())
|
||||
&& !tool.getQueryModes().contains(query.getQueryMode())) {
|
||||
return true;
|
||||
@@ -73,17 +73,17 @@ public class AgentCheckParser implements SemanticParser {
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static List<RuleQueryTool> getRuleTools(Integer agentId) {
|
||||
private static List<RuleParserTool> getRuleTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.RULE);
|
||||
List<String> tools = agent.getTools(AgentToolType.NL2SQL_RULE);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class))
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleParserTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
package com.tencent.supersonic.chat.parser.sql.rule;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.AVG;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.COUNT;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
package com.tencent.supersonic.chat.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
package com.tencent.supersonic.chat.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -8,14 +8,22 @@ import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* QueryModeParser resolves a specific query mode according to co-appearance
|
||||
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
|
||||
* of certain schema element types.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryModeParser implements SemanticParser {
|
||||
public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
|
||||
new ContextInheritParser(),
|
||||
new AgentCheckParser(),
|
||||
new TimeRangeParser(),
|
||||
new AggregateTypeParser()
|
||||
);
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
@@ -29,6 +37,7 @@ public class QueryModeParser implements SemanticParser {
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
package com.tencent.supersonic.chat.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
public class ChatQueryDO {
|
||||
/**
|
||||
*/
|
||||
@@ -43,155 +45,6 @@ public class ChatQueryDO {
|
||||
*/
|
||||
private String queryResult;
|
||||
|
||||
/**
|
||||
* @return question_id
|
||||
*/
|
||||
public Long getQuestionId() {
|
||||
return questionId;
|
||||
}
|
||||
private String similarQueries;
|
||||
|
||||
/**
|
||||
* @param questionId
|
||||
*/
|
||||
public void setQuestionId(Long questionId) {
|
||||
this.questionId = questionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return agent_id
|
||||
*/
|
||||
public Integer getAgentId() {
|
||||
return agentId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param agentId
|
||||
*/
|
||||
public void setAgentId(Integer agentId) {
|
||||
this.agentId = agentId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return create_time
|
||||
*/
|
||||
public Date getCreateTime() {
|
||||
return createTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param createTime
|
||||
*/
|
||||
public void setCreateTime(Date createTime) {
|
||||
this.createTime = createTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return user_name
|
||||
*/
|
||||
public String getUserName() {
|
||||
return userName;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param userName
|
||||
*/
|
||||
public void setUserName(String userName) {
|
||||
this.userName = userName == null ? null : userName.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return query_state
|
||||
*/
|
||||
public Integer getQueryState() {
|
||||
return queryState;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param queryState
|
||||
*/
|
||||
public void setQueryState(Integer queryState) {
|
||||
this.queryState = queryState;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return chat_id
|
||||
*/
|
||||
public Long getChatId() {
|
||||
return chatId;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param chatId
|
||||
*/
|
||||
public void setChatId(Long chatId) {
|
||||
this.chatId = chatId;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return score
|
||||
*/
|
||||
public Integer getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param score
|
||||
*/
|
||||
public void setScore(Integer score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return feedback
|
||||
*/
|
||||
public String getFeedback() {
|
||||
return feedback;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param feedback
|
||||
*/
|
||||
public void setFeedback(String feedback) {
|
||||
this.feedback = feedback == null ? null : feedback.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return query_text
|
||||
*/
|
||||
public String getQueryText() {
|
||||
return queryText;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param queryText
|
||||
*/
|
||||
public void setQueryText(String queryText) {
|
||||
this.queryText = queryText == null ? null : queryText.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return query_result
|
||||
*/
|
||||
public String getQueryResult() {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param queryResult
|
||||
*/
|
||||
public void setQueryResult(String queryResult) {
|
||||
this.queryResult = queryResult == null ? null : queryResult.trim();
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user