4 Commits

Author SHA1 Message Date
jipeli
40ea6a9396 (feature)(headless) Add tag query api (#790) 2024-03-06 11:41:47 +08:00
daikon
78d724ea83 [improvement](headless) add queryTag for tagMarket (#772) 2024-02-27 20:34:00 +08:00
lexluo09
eadbdc4e30 Merge pull request #759 from lexluo09/master
(improvement)(project) merge master to dev-0.9
2024-02-26 14:41:23 +08:00
jipeli
b8831317e9 (feature)(headless) Add tag rest api (#733) 2024-02-21 17:45:28 +08:00
1037 changed files with 35677 additions and 43073 deletions

View File

@@ -1,4 +1,4 @@
name: supersonic ubuntu CI
name: supersonic CI
on:
push:

View File

@@ -1,35 +0,0 @@
name: supersonic mac CI
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
build:
runs-on: macos-latest # Specify a macOS runner
steps:
- uses: actions/checkout@v2
- name: Set up JDK 8
uses: actions/setup-java@v2
with:
java-version: '8'
distribution: 'adopt'
- name: Cache Maven packages
uses: actions/cache@v2
with:
path: ~/Library/Caches/Maven # macOS Maven cache path
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
restore-keys: ${{ runner.os }}-m2
- name: Build with Maven
run: mvn -B package --file pom.xml
- name: Test with Maven
run: mvn test

View File

@@ -1,35 +0,0 @@
name: supersonic windows CI
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
build:
runs-on: windows-latest # Specify a Windows runner
steps:
- uses: actions/checkout@v2
- name: Set up JDK 8
uses: actions/setup-java@v2
with:
java-version: '8'
distribution: 'adopt'
- name: Cache Maven packages
uses: actions/cache@v2
with:
path: ~\.m2 # Windows uses a backslash for paths
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
restore-keys: ${{ runner.os }}-m2
- name: Build with Maven
run: mvn -B package --file pom.xml
- name: Test with Maven
run: mvn test

2
.gitignore vendored
View File

@@ -8,7 +8,6 @@ log/
*.bin
*.log
*.tar.gz
*.zip
*.lib
assembly/runtime/*
**/dist/
@@ -19,4 +18,3 @@ assembly/runtime/*
chm_db/
__pycache__/
/dict
assembly/build/*-SNAPSHOT

View File

@@ -4,29 +4,6 @@
- "Breaking Changes" describes any changes that may break existing functionality or cause
compatibility issues with previous versions.
## SuperSonic [0.9.2] - 2024-06-01
### Added
- support multiple rounds of dialogue
- add term configuration and identification to help LLM learn private domain knowledge
- support configuring LLM parameters in the agent
- metric market supports searching in natural language
### Updated
- introducing WorkFlow, Mapper, Parser, and Corrector support jump execution
- Introducing the concept of Model-Set to simplify Domain management
- overall optimization and upgrade of system pages
- optimize startup script
## SuperSonic [0.9.0] - 2024-04-03
### Added
- add tag abstraction and enhance tag marketplace management.
- headless-server provides Chat API interface.
### Updated
- migrate chat-core core component to headless-core.
## SuperSonic [0.8.6] - 2024-02-23
### Added

View File

@@ -2,39 +2,31 @@
![Java CI](https://github.com/tencentmusic/supersonic/workflows/supersonic%20CI/badge.svg)
# SuperSonic
# SuperSonic (超音数)
SuperSonic is the next-generation BI platform that integrates **Chat BI** (powered by LLM) and **Headless BI** (powered by semantic layer) paradigms. This integration ensures that Chat BI has access to the same curated and governed semantic data models as traditional BI. Furthermore, the implementation of both paradigms benefits from the integration:
- Chat BI's Text2SQL gets augmented with context-retrieval from semantic models.
- Headless BI's query interface gets extended with natural language API.
<img src="./docs/images/supersonic_ideas.png" height="75%" width="75%" align="center"/>
SuperSonic provides a **Chat BI interface** that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of metric/dimension/tag, along with their meaning and relationships) through a **Headless BI interface**. Meanwhile, SuperSonic is designed to be extensible and composable, allowing custom implementations to be added and configured with Java SPI.
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
## Motivation
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved, leading to a new paradigm in the field of data analytics known as Chat BI. To implement Chat BI, both academia and industry are primarily focused on harnessing the power of LLMs to convert natural language into SQL, commonly referred to as Text2SQL or NL2SQL. While some approaches show promising results, their **reliability** falls short for large-scale real-world applications.
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language into SQL (so called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications.
Meanwhile, another emerging paradigm called Headless BI, which focuses on constructing unified semantic data models, has garnered significant attention. Headless BI is implemented through a universal semantic layer that exposes consistent data semantics via an open API.
From our perspective, the key to filling the real-world gap lies in three aspects:
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to **mitigate hallucination**.
3. Utilize rule-based schema parsers when necessary to **improve efficiency**(in terms of latency and cost).
From our perspective, the integration of Chat BI and Headless BI has the potential to enhance the Text2SQL generation in two dimensions:
1. Incorporate data semantics (such as business terms, column values, etc.) into the prompt, enabling LLM to better understand the semantics and **reduce hallucination**.
2. Offload the generation of advanced SQL syntax (such as join, formula, etc.) from LLM to the semantic layer to **reduce complexity**.
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development we decide to open source SuperSonic as an extensible framework.
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of ChatBI, we decide to open source SuperSonic as an extensible framework.
## Out-of-the-box Features
- Built-in ChatBI interface for *business users* to enter natural language queries
- Built-in Headless BI interface for *analytics engineers* to build semantic data models
- Built-in rule-based semantic parser to improve efficiency in certain scenarios (e.g. demonstration, integration testing)
- Built-in support for input auto-completion, multi-turn conversation as well as post-query recommendation
- Built-in support for three-level data access control: dataset-level, column-level and row-level
- Built-in HeadlessBI interface for *analytics engineers* to build semantic models
- Built-in GUI for *system administrators* to manage chat agents and third-party plugins
- Support input auto-completion as well as query recommendation
- Support multi-turn conversation and history context management
- Support four-level permission control: domain-level, model-level, column-level and row-level
## Extensible Components
@@ -46,11 +38,11 @@ The high-level architecture and main process flow is as follows:
- **Schema Mapper:** identifies references to schema elements(metrics/dimensions/entities/values) in user queries. It matches the query text against the knowledge base.
- **Semantic Parser:** understands user queries and generates semantic query statement. It consists of a combination of rule-based and model-based parsers, each of which deals with specific scenarios.
- **Semantic Parser:** understands user queries and extracts semantic information. It consists of a combination of rule-based and model-based parsers, each of which deals with specific scenarios.
- **Semantic Corrector:** checks validity of semantic query statement and performs correction and optimization if needed.
- **Semantic Corrector:** checks validity of extracted semantic information and performs correction and optimization if needed.
- **Semantic Translator:** converts semantic query statement into SQL statement that can be executed against physical data models.
- **Semantic Interpreter:** performs execution according to extracted semantic information. It generates SQL statements and executes them against physical data models.
- **Chat Plugin:** extends functionality with third-party tools. The LLM is going to select the most suitable one, given all configured plugins with function description and sample questions.

View File

@@ -1,26 +1,17 @@
# SuperSonic
# SuperSonic (超音数)
**SuperSonic融合Chat BIpowered by LLM和Headless BIpowered by 语义层)打造新一代的BI平台**这种融合确保了Chat BI能够与传统BI一样访问统一化治理的语义数据模型。此外两种BI新范式都从中获得收益
- Chat BI的Text2SQL生成通过检索语义数据模型得到增强。
- Headless BI的查询接口通过支持自然语言API得到拓展。
<img src="./docs/images/supersonic_ideas.png" height="75%" width="75%" align="center"/>
通过SuperSonic的问答对话界面用户能够使用自然语言查询数据系统会选择合适的可视化图表呈现结果。SuperSonic不需要修改或复制数据只需要在物理数据模型之上构建逻辑语义模型定义指标/维度/实体/标签以及它们的业务含义、相互关系等即可开启数据问答体验。与此同时SuperSonic被设计为可插拔的框架采用Java SPI机制来扩展定制功能。
**SuperSonic融合ChatBI和HeadlessBI打造新一代的数据分析平台**通过SuperSonic的问答对话界面用户能够使用自然语言查询数据系统会选择合适的可视化图表呈现结果。SuperSonic不需要修改或复制数据只需要在物理数据模型之上构建逻辑语义模型指标/维度/实体的定义以及他们的业务含义、相互间关系等即可开启数据问答体验。与此同时SuperSonic被设计为可插拔的框架采用Java SPI机制来扩展定制功能。
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
## 项目动机
大型语言模型LLM如ChatGPT的出现正在重塑信息检索的方式,引领数据分析领域的一种新范式被称为Chat BI。为了实现Chat BI,学术界和工业界主要关注利用LLM的能力将自然语言转换为SQL通常称为Text2SQL或NL2SQL。尽管一些方法显示出有希望的结果,但它们在大规模实际应用中的可靠性还不足
大型语言模型LLMs如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们的可靠性还达不到生产可用的要求
与此同时另一种新兴范式被称为Headless BI它专注于构建统一的语义数据模型并引起了广泛的关注。Headless BI通过一个通用的语义层来实现通过开放的API公开一致的数据语义。
从我们的角度来看Chat BI和Headless BI的融合有潜力在两个方面增强Text2SQL的能力
1. 将数据语义如业务术语、列值等纳入提示词中使LLM能够更好地理解语义以**减少幻觉**。
2. 将高级SQL语法如连接、公式等的生成从LLM卸载到语义层以**减少复杂度**。
在我们看来,为了在实际场景发挥价值,有三个关键点:
1. 融合HeadlessBI通过统一语义层封装底层数据细节关联、键值、公式等降低SQL生成的**复杂度**。
2. 通过一前一后的模式映射器和语义修正器来缓解LLM常见的**幻觉**现象。
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
为了验证上述想法我们开发了SuperSonic项目并将其应用在实际的内部产品中。与此同时我们将SuperSonic作为一个可扩展的框架开源希望能够促进数据问答对话领域的进一步发展。
@@ -28,9 +19,10 @@
- 内置ChatBI界面以便*业务用户*输入数据查询。
- 内置HeadlessBI界面以便*分析工程师*构建语义模型。
- 内置基于规则的语义解析器在特定场景比如DEMO演示、集成测试可以提升推理效率
- 支持文本输入联想、多轮对话、查询问题推荐等高级特征
- 支持三级权限控制:数据集级、列级、行级
- 内置图形用户界面以便*系统管理员*管理第三方插件和对话助理
- 支持文本输入联想查询问题推荐。
- 支持多轮对话,根据语境自动切换上下文
- 支持四级权限控制:主题域级、模型级、列级、行级。
## 易于扩展的组件
@@ -42,11 +34,11 @@ SuperSonic的整体架构和主流程如下图所示
- **模式映射器(Schema Mapper)** 将自然语言文本在知识库中进行匹配,为后续的语义解析提供相关信息。
- **语义解析器(Semantic Parser)** 理解用户查询并抽取语义信息,生成语义查询语句S2SQL
- **语义解析器(Semantic Parser)** 理解用户查询并抽取语义信息,其由一组基于规则和基于模型的解析器组成,每个解析器可应对不同的特定场景
- **语义修正器(Semantic Corrector)** 检查语义查询语句的合法性,对不合法的信息做修正和优化处理。
- **语义修正器(Semantic Corrector)** 检查语义信息的合法性,对不合法的信息做修正和优化处理。
- **语义翻译器(Semantic Translator)** 将语义查询语句翻译成可在物理数据模型上执行的SQL语句
- **语义解释器(Semantic Interpreter)** 根据语义信息生成物理SQL执行查询
- **问答插件(Chat Plugin)** 通过第三方工具扩展功能。给定所有配置的插件及其功能描述和示例问题,大语言模型将选择最合适的插件。

View File

@@ -1,98 +1,72 @@
@echo off
setlocal enabledelayedexpansion
setlocal
chcp 65001
call supersonic-common.bat %*
set "sbinDir=%~dp0"
set "baseDir=%~dp0.."
set "buildDir=%baseDir%\build"
set "runtimeDir=%baseDir%\..\runtime"
set "pip_path=pip3"
set "service=%~1"
cd %projectDir%
if "%service%"=="" (
set service=%standalone_service%
)
call mvn help:evaluate -Dexpression=project.version > temp.txt
for /f "delims=" %%i in (temp.txt) do (
set line=%%i
if not "!line:~0,1!"=="[" (
set MVN_VERSION=!line!
)
)
del temp.txt
cd %baseDir%
rem 1. build backend java modules
del /q "%buildDir%\*.tar.gz" 2>NUL
call mvn -f "%baseDir%\..\pom.xml" clean package -DskipTests
if "%service%"=="%pyllm_service%" (
echo start installing python modules required by supersonic-pyllm: %pip_path%
%pip_path% install -r %projectDir%\headless\python\requirements.txt"
echo install python modules success
goto :EOF
) else if "%service%"=="webapp" (
call :buildWebapp
tar xvf supersonic-webapp.tar.gz
move /y supersonic-webapp webapp
move /y webapp %projectDir%\launchers\%STANDALONE_SERVICE%\target\classes
goto :EOF
) else (
call :buildJavaService
call :buildWebapp
call :packageRelease
goto :EOF
)
:buildJavaService
set "model_name=%service%"
echo "starting building supersonic-%model_name% service"
call mvn -f %projectDir%\launchers\%model_name% clean package -DskipTests
IF ERRORLEVEL 1 (
ECHO Failed to build backend Java modules.
EXIT /B 1
)
copy /y %projectDir%\launchers\%model_name%\target\*.tar.gz %buildDir%\
echo "finished building supersonic-%model_name% service"
goto :EOF
rem 2. move package to build
echo f|xcopy "%baseDir%\..\launchers\standalone\target\*.tar.gz" "%buildDir%\supersonic-standalone.tar.gz"
:buildWebapp
echo "starting building supersonic webapp"
cd %projectDir%\webapp
rem 3. build frontend webapp
cd "%baseDir%\..\webapp"
call start-fe-prod.bat
copy /y supersonic-webapp.tar.gz %buildDir%\
rem check build result
copy /y "%baseDir%\..\webapp\supersonic-webapp.tar.gz" "%buildDir%\"
IF ERRORLEVEL 1 (
ECHO Failed to build frontend webapp.
EXIT /B 1
)
echo "finished building supersonic webapp"
goto :EOF
rem 4. copy webapp to java classpath
cd "%buildDir%"
tar -zxvf supersonic-webapp.tar.gz
move supersonic-webapp webapp
move webapp ..\..\launchers\standalone\target\classes
:packageRelease
set "model_name=%service%"
set "release_dir=supersonic-%model_name%-%MVN_VERSION%"
set "service_name=launchers-%model_name%-%MVN_VERSION%"
echo "starting packaging supersonic release"
cd %buildDir%
if exist %release_dir% rmdir /s /q %release_dir%
if exist %release_dir%.zip del %release_dir%.zip
mkdir %release_dir%
rem package webapp
tar xvf supersonic-webapp.tar.gz
move /y supersonic-webapp webapp
echo {"env": ""} > webapp\supersonic.config.json
move /y webapp %release_dir%
rem package java service
tar xvf %service_name%-bin.tar.gz
for /d %%D in ("%service_name%\*") do (
move "%%D" "%release_dir%"
rem 5. build backend python modules
if "%service%"=="pyllm" (
echo "start installing python modules with pip: ${pip_path}"
set requirementPath="%baseDir%/../chat/python/requirements.txt"
%pip_path% install -r %requirementPath%
echo "install python modules success"
)
rem generate zip file
powershell Compress-Archive -Path %release_dir% -DestinationPath %release_dir%.zip
del %service_name%-bin.tar.gz
del supersonic-webapp.tar.gz
rmdir /s /q %service_name%
echo "finished packaging supersonic release"
goto :EOF
call :BUILD_RUNTIME
:BUILD_RUNTIME
rem 6. reset runtime
IF EXIST "%runtimeDir%" (
echo begin to delete dir : %runtimeDir%
rd /s /q "%runtimeDir%"
) ELSE (
echo %runtimeDir% does not exist, create directly
)
mkdir "%runtimeDir%"
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
move "%%f" "%runtimeDir%\supersonic-standalone"
)
rem 7. copy webapp to runtime
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
rd /s /q "%buildDir%\supersonic-webapp"
endlocal

View File

@@ -1,80 +1,58 @@
#!/usr/bin/env bash
set -x
sbinDir=$(cd "$(dirname "$0")"; pwd)
chmod +x $sbinDir/supersonic-common.sh
source $sbinDir/supersonic-common.sh
cd $projectDir
MVN_VERSION=$(mvn help:evaluate -Dexpression=project.version | grep -e '^[^\[]')
cd $baseDir
service=$1
if [ -z "$service" ]; then
service=${STANDALONE_SERVICE}
fi
function buildJavaService {
model_name=$1
echo "starting building supersonic-${model_name} service"
mvn -f $projectDir clean package -DskipTests
service=$1
#1. build backend java modules
rm -fr ${buildDir}/*.tar.gz
rm -fr dist
set +x
mvn -f $baseDir/../ clean package -DskipTests
# check build result
if [ $? -ne 0 ]; then
echo "Failed to build backend Java modules."
exit 1
fi
cp $projectDir/launchers/${model_name}/target/*.tar.gz ${buildDir}/
echo "finished building supersonic-${model_name} service"
}
function buildWebapp {
echo "starting building supersonic webapp"
chmod +x $projectDir/webapp/start-fe-prod.sh
cd $projectDir/webapp
#2. move package to build
cp $baseDir/../launchers/headless/target/*.tar.gz ${buildDir}/supersonic-headless.tar.gz
cp $baseDir/../launchers/chat/target/*.tar.gz ${buildDir}/supersonic-chat.tar.gz
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic-standalone.tar.gz
#3. build frontend webapp
chmod +x $baseDir/../webapp/start-fe-prod.sh
cd ../webapp
sh ./start-fe-prod.sh
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/
# check build result
if [ $? -ne 0 ]; then
echo "Failed to build frontend webapp."
exit 1
fi
echo "finished building supersonic webapp"
}
function packageRelease {
model_name=$1
release_dir=supersonic-${model_name}-${MVN_VERSION}
service_name=launchers-${model_name}-${MVN_VERSION}
echo "starting packaging supersonic release"
#4. copy webapp to java classpath
cd $buildDir
mkdir $release_dir
# package webapp
tar xvf supersonic-webapp.tar.gz
mv supersonic-webapp webapp
json='{"env": "''"}'
echo $json > webapp/supersonic.config.json
mv webapp $release_dir/
# package java service
tar xvf $service_name-bin.tar.gz
mv $service_name/* $release_dir/
# generate zip file
zip -r $release_dir.zip $release_dir
# delete intermediate files
rm supersonic-webapp.tar.gz $service_name-bin.tar.gz
rm -rf webapp $service_name $release_dir
echo "finished packaging supersonic release"
}
cp -fr webapp ../../launchers/headless/target/classes
cp -fr webapp ../../launchers/chat/target/classes
cp -fr webapp ../../launchers/standalone/target/classes
rm -fr ${buildDir}/webapp
#1. build backend services
if [ "$service" == $PYLLM_SERVICE ]; then
echo "start installing python modules required by supersonic-pyllm: ${pip_path}"
requirementPath=$projectDir/headless/python/requirements.txt
#5. build backend python modules
if [ "$service" == "pyllm" ]; then
echo "start installing python modules with pip: ${pip_path}"
requirementPath=$baseDir/../chat/python/requirements.txt
${pip_path} install -r ${requirementPath}
echo "install python modules success"
elif [ "$service" == "webapp" ]; then
buildWebapp
target_path=$projectDir/launchers/$STANDALONE_SERVICE/target/classes
tar xvf $projectDir/webapp/supersonic-webapp.tar.gz -C $target_path
mv $target_path/supersonic-webapp $target_path/webapp
else
buildJavaService $service
buildWebapp
packageRelease $service
fi
#6. reset runtime
rm -fr $runtimeDir/supersonic*
moveAllToRuntime
setEnvToWeb chat
setEnvToWeb headless

View File

@@ -1,9 +0,0 @@
set "sbinDir=%~dp0"
set "baseDir=%~dp0.."
set "buildDir=%baseDir%\build"
set "main_class=com.tencent.supersonic.StandaloneLauncher"
set "python_path=python"
set "pip_path=pip3"
set "standalone_service=standalone"
set "pyllm_service=pyllm"
set "projectDir=%baseDir%\.."

View File

@@ -6,19 +6,105 @@ pip_path=${PIP_PATH:-"pip3"}
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
runtimeDir=$baseDir/runtime
runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build
projectDir=$baseDir/..
readonly CHAT_APP_NAME="supersonic_chat"
readonly HEADLESS_APP_NAME="supersonic_headless"
readonly PYLLM_APP_NAME="supersonic_pyllm"
readonly STANDALONE_APP_NAME="supersonic_standalone"
readonly CHAT_SERVICE="chat"
readonly HEADLESS_SERVICE="headless"
readonly PYLLM_SERVICE="pyllm"
readonly STANDALONE_SERVICE="standalone"
readonly PYLLM_HOST="127.0.0.1"
readonly PYLLM_PORT="9092"
function setEnvToWeb {
model_name=$1
json='{"env": "'$model_name'"}'
echo $json > ${runtimeDir}/supersonic-${model_name}/webapp/supersonic.config.json
echo $json > $baseDir/../launchers/${model_name}/target/classes/webapp/supersonic.config.json
}
function moveToRuntime {
model_name=$1
file="${buildDir}/supersonic-${model_name}.tar.gz"
if [ -f "$file" ]; then
tar -zxvf "$file" -C ${runtimeDir}
mv ${runtimeDir}/launchers-${model_name}-* ${runtimeDir}/supersonic-${model_name}
mkdir -p ${runtimeDir}/supersonic-${model_name}/webapp
cp -fr ${buildDir}/webapp/* ${runtimeDir}/supersonic-${model_name}/webapp
else
echo "File $file does not exist. Skipping the move to runtime."
fi
}
function moveAllToRuntime {
mkdir -p ${runtimeDir}
tar xvf ${buildDir}/supersonic-webapp.tar.gz -C ${buildDir}
mv ${buildDir}/supersonic-webapp ${buildDir}/webapp
moveToRuntime chat
moveToRuntime headless
moveToRuntime standalone
rm -fr ${buildDir}/webapp
}
# run java service
function runJavaService {
javaRunDir=${runtimeDir}/supersonic-${model_name}
local_app_name=$1
libDir=$javaRunDir/lib
confDir=$javaRunDir/conf
CLASSPATH=""
CLASSPATH=$CLASSPATH:$confDir
for jarPath in $libDir/*.jar; do
CLASSPATH=$CLASSPATH:$jarPath
done
export CLASSPATH
export LANG="zh_CN.UTF-8"
cd $javaRunDir
if [[ "$JAVA_HOME" == "" ]]; then
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
fi
export PATH=$JAVA_HOME/bin:$PATH
command="-Dfile.encoding="UTF-8" -Duser.language="Zh" -Duser.region="CN" -Duser.timezone="GMT+08" -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m "$main_class
mkdir -p $javaRunDir/logs
if [[ "$is_test" == "true" ]]; then
java -Dspring.profiles.active="dev" $command >/dev/null 2>$javaRunDir/logs/error.log &
else
java $command $javaRunDir >/dev/null 2>$javaRunDir/logs/error.log &
fi
}
# run python service
function runPythonService {
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
cd $pythonRunDir
nohup ${python_path} supersonic_pyllm.py > $pythonRunDir/pyllm.log 2>&1 &
# add health check
for i in {1..10}
do
echo "pyllm health check attempt $i..."
response=$(curl -s http://${PYLLM_HOST}:${PYLLM_PORT}/health)
echo "pyllm health check response: $response"
status_ok="Healthy"
if [[ $response == *$status_ok* ]] ; then
echo "pyllm Health check passed."
break
else
if [ "$i" -eq 10 ]; then
echo "pyllm Health check failed after 10 attempts."
echo "May still downloading model files. Please check pyllm.log in runtime directory."
fi
echo "Retrying after 5 seconds..."
sleep 5
fi
done
}

View File

@@ -1,102 +1,118 @@
@echo off
setlocal
chcp 65001
set "sbinDir=%~dp0"
set "baseDir=%~dp0.."
set "runtimeDir=%baseDir%\..\runtime"
set "buildDir=%baseDir%\build"
set "main_class=com.tencent.supersonic.StandaloneLauncher"
set "python_path=python"
set "pip_path=pip3"
set "standalone_service=standalone"
set "pyllm_service=pyllm"
call supersonic-common.bat %*
call %sbinDir%/../conf/supersonic-env.bat %*
set "javaRunDir=%runtimeDir%\supersonic-standalone"
set "pythonRunDir=%runtimeDir%\supersonic-standalone\pyllm"
set "command=%~1"
set "service=%~2"
if "%service%"=="" (
set "service=%standalone_service%"
)
set "model_name=%service%"
IF "%service%"=="pyllm" (
set "llmProxy=PythonLLMProxy"
set "model_name=%standalone_service%"
SET "llmProxy=PythonLLMProxy"
)
cd %baseDir%
call :BUILD_RUNTIME
if "%command%"=="restart" (
call :stop
call :start
call :STOP
call :START
goto :EOF
) else if "%command%"=="start" (
call :start
call :START
goto :EOF
) else if "%command%"=="stop" (
call :stop
call :STOP
goto :EOF
) else if "%command%"=="reload" (
call :reloadExamples
call :RELOAD_EXAMPLE
goto :EOF
) else (
echo "Use command {start|stop|restart} to run."
goto :EOF
)
: start
:START
if "%service%"=="%pyllm_service%" (
call :runPythonService
call :runJavaService
call :START_PYTHON
call :START_JAVA
goto :EOF
)
call :runJavaService
call :START_JAVA
goto :EOF
: stop
call :stopPythonService
call :stopJavaService
:STOP
call :STOP_PYTHON
call :STOP_JAVA
goto :EOF
: reloadExamples
set "pythonRunDir=%baseDir%\pyllm"
cd "%pythonRunDir%\sql"
start %python_path% examples_reload_run.py
:START_PYTHON
echo 'python service starting, see logs in pyllm/pyllm.log'
cd "%pythonRunDir%"
start /B %python_path% supersonic_pyllm.py > %pythonRunDir%\pyllm.log 2>&1
timeout /t 10 >nul
echo 'python service started'
goto :EOF
: runJavaService
:START_JAVA
echo 'java service starting, see logs in logs/'
set "libDir=%baseDir%\lib"
set "confDir=%baseDir%\conf"
set "webDir=%baseDir%\webapp"
set "logDir=%baseDir%\logs"
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
cd "%javaRunDir%"
if not exist "%runtimeDir%\supersonic-standalone\logs" mkdir "%runtimeDir%\supersonic-standalone\logs"
set "libDir=%runtimeDir%\supersonic-standalone\lib"
set "confDir=%runtimeDir%\supersonic-standalone\conf"
set "webDir=%runtimeDir%\supersonic-standalone\webapp"
set "classpath=%confDir%;%webDir%;%libDir%\*"
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
if not exist %logDir% mkdir %logDir%
start /B java %java-command% >nul 2>&1
timeout /t 10 >nul
echo 'java service started'
goto :EOF
: runPythonService
echo 'python service starting, see logs in pyllm\pyllm.log'
set "pythonRunDir=%baseDir%\pyllm"
start /B %python_path% %pythonRunDir%\supersonic_pyllm.py > %pythonRunDir%\pyllm.log 2>&1
timeout /t 10 >nul
echo 'python service started'
goto :EOF
: stopPythonService
:STOP_PYTHON
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
taskkill /PID %%i /F
echo "python service (PID = %%i) is killed."
)
goto :EOF
: stopJavaService
:STOP_JAVA
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
taskkill /PID %%i /F
echo "java service (PID = %%i) is killed."
)
goto :EOF
endlocal
:RELOAD_EXAMPLE
cd "%runtimeDir%\supersonic-standalone\pyllm\sql"
start %python_path% examples_reload_run.py
goto :EOF
:BUILD_RUNTIME
rem 6. reset runtime
if exist "%runtimeDir%" goto :EOF
mkdir "%runtimeDir%"
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
move "%%f" "%runtimeDir%\supersonic-standalone"
)
rem 7. copy webapp to runtime
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
rd /s /q "%buildDir%\supersonic-webapp"

View File

@@ -1,11 +1,16 @@
#!/usr/bin/env bash
set -x
sbinDir=$(cd "$(dirname "$0")"; pwd)
chmod +x $sbinDir/supersonic-common.sh
source $sbinDir/supersonic-common.sh
set -a
source $sbinDir/../conf/supersonic-env.sh
set +a
# 1.init environment parameters
if [ ! -d "$runtimeDir" ]; then
echo "the runtime dir does not exist move all to runtime"
moveAllToRuntime
fi
set +x
command=$1
service=$2
@@ -13,93 +18,44 @@ if [ -z "$service" ]; then
service=${STANDALONE_SERVICE}
fi
app_name=$STANDALONE_APP_NAME
main_class="com.tencent.supersonic.StandaloneLauncher"
model_name=$service
if [ "$service" == "pyllm" ]; then
model_name=${STANDALONE_SERVICE}
export llmProxy=PythonLLMProxy
fi
cd $baseDir
# 2.set main class
function setMainClass {
if [ "$service" == $CHAT_SERVICE ]; then
main_class="com.tencent.supersonic.ChatLauncher"
elif [ "$service" == $HEADLESS_SERVICE ]; then
main_class="com.tencent.supersonic.HeadlessLauncher"
else
main_class="com.tencent.supersonic.StandaloneLauncher"
fi
}
setMainClass
# 3.set app name
function setAppName {
if [ "$service" == $CHAT_SERVICE ]; then
app_name=$CHAT_APP_NAME
elif [ "$service" == $HEADLESS_SERVICE ]; then
app_name=$HEADLESS_APP_NAME
else
app_name=$STANDALONE_APP_NAME
elif [ "$service" == $PYLLM_SERVICE ]; then
app_name=$PYLLM_APP_NAME
fi
}
setAppName
function reloadExamples {
cd $baseDir/pyllm/sql
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
cd $pythonRunDir/sql
${python_path} examples_reload_run.py
}
function runJavaService {
javaRunDir=$baseDir
local_app_name=$1
libDir=$baseDir/lib
confDir=$baseDir/conf
CLASSPATH=""
CLASSPATH=$CLASSPATH:$confDir
for jarPath in $libDir/*.jar; do
CLASSPATH=$CLASSPATH:$jarPath
done
export CLASSPATH
export LANG="zh_CN.UTF-8"
cd $javaRunDir
if [[ "$JAVA_HOME" == "" ]]; then
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
fi
export PATH=$JAVA_HOME/bin:$PATH
command="-Dfile.encoding="UTF-8" -Duser.language="Zh" -Duser.region="CN" -Duser.timezone="GMT+08" -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m "$main_class
mkdir -p $javaRunDir/logs
if [[ "$is_test" == "true" ]]; then
java -Dspring.profiles.active="dev" $command >/dev/null 2>$javaRunDir/logs/error.log &
else
java $command $javaRunDir >/dev/null 2>$javaRunDir/logs/error.log &
fi
}
function runPythonService {
pythonRunDir=$baseDir/pyllm
cd $pythonRunDir
nohup ${python_path} supersonic_pyllm.py > $pythonRunDir/pyllm.log 2>&1 &
# add health check
for i in {1..10}
do
echo "pyllm health check attempt $i..."
response=$(curl -s http://${PYLLM_HOST}:${PYLLM_PORT}/health)
echo "pyllm health check response: $response"
status_ok="Healthy"
if [[ $response == *$status_ok* ]] ; then
echo "pyllm Health check passed."
break
else
if [ "$i" -eq 10 ]; then
echo "pyllm Health check failed after 10 attempts."
echo "May still downloading model files. Please check pyllm.log in runtime directory."
fi
echo "Retrying after 5 seconds..."
sleep 5
fi
done
}
function start()
{
@@ -137,16 +93,18 @@ function reload()
fi
}
setMainClass
setAppName
# 4. execute command operation
case "$command" in
start)
if [ "$service" == $PYLLM_SERVICE ]; then
echo "Starting $PYLLM_APP_NAME"
start $PYLLM_APP_NAME
echo "Starting $app_name"
start $app_name
echo "Starting $STANDALONE_APP_NAME"
start $STANDALONE_APP_NAME
else
echo "Starting $app_name"
start $app_name
fi
echo "Starting ${app_name}"
start ${app_name}
echo "Start success"
;;
stop)
@@ -163,15 +121,20 @@ case "$command" in
;;
restart)
if [ "$service" == $PYLLM_SERVICE ]; then
echo "Stopping $PYLLM_APP_NAME"
stop $PYLLM_APP_NAME
echo "Starting $PYLLM_APP_NAME"
start $PYLLM_APP_NAME
fi
echo "Stopping ${app_name}"
stop ${app_name}
echo "Stopping ${STANDALONE_APP_NAME}"
stop $STANDALONE_APP_NAME
echo "Starting ${app_name}"
start ${app_name}
echo "Starting ${STANDALONE_APP_NAME}"
start $STANDALONE_APP_NAME
else
echo "Stopping ${app_name}"
stop ${app_name}
echo "Starting ${app_name}"
start ${app_name}
fi
echo "Restart success"
;;
*)

View File

@@ -21,17 +21,11 @@
</includes>
</fileSet>
<fileSet>
<directory>${project.basedir}/../../headless/python</directory>
<directory>${project.basedir}/../../chat/python</directory>
<outputDirectory>pyllm</outputDirectory>
<fileMode>0777</fileMode>
<directoryMode>0755</directoryMode>
</fileSet>
<fileSet>
<directory>${project.basedir}/../../assembly/bin</directory>
<outputDirectory>bin</outputDirectory>
<fileMode>0777</fileMode>
<directoryMode>0755</directoryMode>
</fileSet>
</fileSets>
<dependencySets>

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.auth.api.authentication.utils;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
@@ -20,10 +20,10 @@ public final class UserHolder {
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
User user = REPO.findUser(request, response);
SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class);
SystemConfig systemConfig = sysParameterService.getSystemConfig();
if (!CollectionUtils.isEmpty(systemConfig.getAdmins())
&& systemConfig.getAdmins().contains(user.getName())) {
SysParameterService sysParameterService = ContextUtils.getBean(SysParameterService.class);
SysParameter sysParameter = sysParameterService.getSysParameter();
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
&& sysParameter.getAdmins().contains(user.getName())) {
user.setIsAdmin(1);
}
return user;

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq;
import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import javax.servlet.http.HttpServletRequest;
@@ -18,9 +18,9 @@ import java.util.Set;
@Service
public class UserServiceImpl implements UserService {
private SystemConfigService sysParameterService;
private SysParameterService sysParameterService;
public UserServiceImpl(SystemConfigService sysParameterService) {
public UserServiceImpl(SysParameterService sysParameterService) {
this.sysParameterService = sysParameterService;
}
@@ -28,9 +28,9 @@ public class UserServiceImpl implements UserService {
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
if (user != null) {
SystemConfig systemConfig = sysParameterService.getSystemConfig();
if (!CollectionUtils.isEmpty(systemConfig.getAdmins())
&& systemConfig.getAdmins().contains(user.getName())) {
SysParameter sysParameter = sysParameterService.getSysParameter();
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
&& sysParameter.getAdmins().contains(user.getName())) {
user.setIsAdmin(1);
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

View File

@@ -0,0 +1,33 @@
package com.tencent.supersonic.chat.api.pojo;
import com.google.common.collect.Lists;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class SchemaMapInfo {
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
public Set<Long> getMatchedViewInfos() {
return viewElementMatches.keySet();
}
public List<SchemaElementMatch> getMatchedElements(Long view) {
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
}
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
return viewElementMatches;
}
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
this.viewElementMatches = viewElementMatches;
}
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
viewElementMatches.put(view, elementMatches);
}
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;

View File

@@ -1,11 +1,15 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import java.util.ArrayList;
@@ -22,12 +26,12 @@ public class SemanticParseInfo {
private Integer id;
private String queryMode;
private SchemaElement dataSet;
private SchemaElement view;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private FilterType filterType = FilterType.AND;
private FilterType filterType = FilterType.UNION;
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
private Set<QueryFilter> metricFilters = new LinkedHashSet();
private Set<Order> orders = new LinkedHashSet();
@@ -36,10 +40,10 @@ public class SemanticParseInfo {
private double score;
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
private SqlInfo sqlInfo = new SqlInfo();
private QueryType queryType = QueryType.ID;
private EntityInfo entityInfo;
private String textInfo;
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
@@ -68,11 +72,15 @@ public class SemanticParseInfo {
return metrics;
}
public Long getDataSetId() {
if (dataSet == null) {
public Long getViewId() {
if (view == null) {
return null;
}
return dataSet.getDataSet();
return view.getView();
}
public SchemaElement getModel() {
return view;
}
}

View File

@@ -0,0 +1,157 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
public class SemanticSchema implements Serializable {
private List<ViewSchema> viewSchemaList;
public SemanticSchema(List<ViewSchema> viewSchemaList) {
this.viewSchemaList = viewSchemaList;
}
public void add(ViewSchema schema) {
viewSchemaList.add(schema);
}
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
switch (elementType) {
case ENTITY:
element = getElementsById(elementID, getEntities());
break;
case VIEW:
element = getElementsById(elementID, getViews());
break;
case METRIC:
element = getElementsById(elementID, getMetrics());
break;
case DIMENSION:
element = getElementsById(elementID, getDimensions());
break;
case VALUE:
element = getElementsById(elementID, getDimensionValues());
break;
default:
}
if (element.isPresent()) {
return element.get();
} else {
return null;
}
}
public Map<Long, String> getViewIdToName() {
return viewSchemaList.stream()
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
}
public List<SchemaElement> getDimensionValues() {
List<SchemaElement> dimensionValues = new ArrayList<>();
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
return dimensionValues;
}
public List<SchemaElement> getDimensions() {
List<SchemaElement> dimensions = new ArrayList<>();
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
return dimensions;
}
public List<SchemaElement> getDimensions(Long viewId) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByViewId(viewId, dimensions);
}
public SchemaElement getDimension(Long id) {
List<SchemaElement> dimensions = getDimensions();
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
return dimension.orElse(null);
}
public List<SchemaElement> getTags() {
List<SchemaElement> tags = new ArrayList<>();
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
return tags;
}
public List<SchemaElement> getTags(Long viewId) {
List<SchemaElement> tags = new ArrayList<>();
viewSchemaList.stream().filter(schemaElement ->
viewId.equals(schemaElement.getView().getView()))
.forEach(d -> tags.addAll(d.getTags()));
return tags;
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}
public List<SchemaElement> getMetrics(Long viewId) {
List<SchemaElement> metrics = getMetrics();
return getElementsByViewId(viewId, metrics);
}
public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
}
public List<SchemaElement> getEntities(Long viewId) {
List<SchemaElement> entities = getEntities();
return getElementsByViewId(viewId, entities);
}
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
.collect(Collectors.toList());
}
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> id.equals(schemaElement.getId()))
.findFirst();
}
public SchemaElement getView(Long viewId) {
List<SchemaElement> views = getViews();
return getElementsById(viewId, views).orElse(null);
}
public List<SchemaElement> getViews() {
List<SchemaElement> views = new ArrayList<>();
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
return views;
}
public Map<String, String> getBizNameToName(Long viewId) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(viewId));
allElements.addAll(getMetrics(viewId));
return allElements.stream()
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
public Map<Long, ViewSchema> getViewSchemaMap() {
if (CollectionUtils.isEmpty(viewSchemaList)) {
return new HashMap<>();
}
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
}
}

View File

@@ -1,18 +1,24 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import lombok.Data;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
@Data
public class DataSetSchema {
private SchemaElement dataSet;
public class ViewSchema {
private SchemaElement view;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> tags = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
private Set<SchemaElement> terms = new HashSet<>();
private Set<SchemaElement> tags = new HashSet<>();
private SchemaElement entity = new SchemaElement();
private QueryConfig queryConfig;
@@ -23,8 +29,8 @@ public class DataSetSchema {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case DATASET:
element = Optional.of(dataSet);
case VIEW:
element = Optional.of(view);
break;
case METRIC:
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
@@ -38,8 +44,34 @@ public class DataSetSchema {
case TAG:
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case TERM:
element = terms.stream().filter(e -> e.getId() == elementID).findFirst();
default:
}
if (element.isPresent()) {
return element.get();
} else {
return null;
}
}
public SchemaElement getElement(SchemaElementType elementType, String name) {
Optional<SchemaElement> element = Optional.empty();
switch (elementType) {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case VIEW:
element = Optional.of(view);
break;
case METRIC:
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
break;
case DIMENSION:
element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst();
break;
case VALUE:
element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst();
break;
default:
}

View File

@@ -16,6 +16,16 @@ public class ChatConfigBaseReq {
private Long modelId;
/**
* the chatDetailConfig about the model
*/
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the model
*/
private ChatAggConfigReq chatAggConfig;
/**
* the recommended questions about the model

View File

@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatExecuteReq {
private User user;
private Long queryId;
private Integer chatId;
private int parseId;
private String queryText;
private boolean saveAnswer;
}

View File

@@ -1,22 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import lombok.Data;
import java.util.HashSet;
import java.util.Set;
@Data
public class ChatQueryDataReq {
private User user;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<QueryFilter> dimensionFilters = new HashSet<>();
private Set<QueryFilter> metricFilters = new HashSet<>();
private DateConf dateInfo;
private Long queryId;
private Integer parseId;
}

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.headless.api.pojo.request;
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Builder;
import lombok.Data;
@@ -10,9 +10,11 @@ import lombok.Data;
@Data
public class ExecuteQueryReq {
private User user;
private Long queryId;
private Integer agentId;
private Integer chatId;
private String queryText;
private Long queryId;
private Integer parseId;
private SemanticParseInfo parseInfo;
private boolean saveAnswer;
}

View File

@@ -13,7 +13,7 @@ public class PluginQueryReq {
private String type;
private String dataSet;
private String view;
private String pattern;

View File

@@ -1,14 +1,12 @@
package com.tencent.supersonic.headless.api.pojo.request;
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
import com.tencent.supersonic.common.pojo.DateConf;
import java.util.HashSet;
import java.util.Set;
import lombok.Data;
@Data
public class QueryDataReq {
@@ -19,5 +17,5 @@ public class QueryDataReq {
private Set<QueryFilter> metricFilters = new HashSet<>();
private DateConf dateInfo;
private Long queryId;
private SemanticParseInfo parseInfo;
private Integer parseId;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo.request;
package com.tencent.supersonic.chat.api.pojo.request;
import com.google.common.base.Objects;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo.request;
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.ArrayList;

View File

@@ -1,19 +1,15 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data;
@Data
public class ChatParseReq {
public class QueryReq {
private String queryText;
private Integer chatId;
private Integer agentId;
private Integer topN = 10;
private Long modelId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private Integer agentId;
}

View File

@@ -18,7 +18,7 @@ public class SimilarQueryReq {
private String queryText;
private Long dataSetId;
private Long viewId;
private Integer agentId;

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
@Data
public class AggregateInfo {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
@@ -8,7 +8,7 @@ import java.util.List;
@Data
public class EntityInfo {
private DataSetInfo dataSetInfo = new DataSetInfo();
private ViewInfo viewInfo = new ViewInfo();
private List<DataInfo> dimensions = new ArrayList<>();
private List<DataInfo> metrics = new ArrayList<>();
private String entityId;

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.headless.api.pojo;
import lombok.Data;
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.Map;
import lombok.Data;
@Data
public class MetricInfo {

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.List;
@Data
public class ParseResp {
private Integer chatId;
private String queryText;
private Long queryId;
private ParseState state;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
public enum ParseState {
COMPLETED,
PENDING,
FAILED
}
}

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.headless.api.pojo.response;
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
@Data
public class ParseTimeCostResp {
public class ParseTimeCostDO {
private long parseStartTime;
private long parseTime;
private long sqlTime;
public ParseTimeCostResp() {
public ParseTimeCostDO() {
this.parseStartTime = System.currentTimeMillis();
}
}

View File

@@ -1,12 +1,10 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.Date;
import java.util.List;
@Data
public class QueryResp {
@@ -20,4 +18,5 @@ public class QueryResp {
private List<SemanticParseInfo> parseInfos;
private List<SimilarQueryRecallResp> similarQueries;
}

View File

@@ -1,18 +1,18 @@
package com.tencent.supersonic.headless.api.pojo.response;
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.QueryAuthorization;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class QueryResult {
public EntityInfo entityInfo;
public AggregateInfo aggregateInfo;
private Long queryId;
private String queryMode;
private String querySql;
@@ -22,9 +22,6 @@ public class QueryResult {
private SemanticParseInfo chatContext;
private Object response;
private List<Map<String, Object>> queryResults;
private String textResult;
private Long queryTimeCost;
private EntityInfo entityInfo;
private List<SchemaElement> recommendedDimensions;
private AggregateInfo aggregateInfo;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo.response;
package com.tencent.supersonic.chat.api.pojo.response;
public enum QueryState {
SUCCESS,

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo.response;
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.List;
import lombok.Data;

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.headless.api.pojo.response;
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import java.util.Objects;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
import java.util.Objects;
@Data
@Setter
@Getter

View File

@@ -10,6 +10,8 @@ public class SimilarQueryRecallResp {
private Long queryId;
private Integer parseId;
private String queryText;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.api.pojo;
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
@@ -6,7 +6,7 @@ import java.io.Serializable;
import java.util.List;
@Data
public class DataSetInfo extends DataInfo implements Serializable {
public class ViewInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryKey;

110
chat/core/pom.xml Normal file
View File

@@ -0,0 +1,110 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>chat</artifactId>
<groupId>com.tencent.supersonic</groupId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>chat-core</artifactId>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
</dependency>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>${org.testng.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>${commons.compress.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>${alibaba.druid.version}</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
</dependency>
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>${h2.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>chat-api</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.github.xkzhangsan</groupId>
<artifactId>xk-time</artifactId>
<version>${xk.time.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.tencent.supersonic</groupId>
<artifactId>headless-server</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
</dependencies>
</project>

View File

@@ -1,10 +1,8 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.springframework.util.CollectionUtils;
@@ -31,8 +29,6 @@ public class Agent extends RecordInfo {
private Integer status;
private List<String> examples;
private String agentConfig;
private LLMConfig llmConfig;
private MultiTurnConfig multiTurnConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);
@@ -69,33 +65,12 @@ public class Agent extends RecordInfo {
.collect(Collectors.toList());
}
public boolean containsLLMParserTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
}
public boolean containsRuleTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
}
public boolean containsNL2SQLTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
}
public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) {
return Sets.newHashSet();
}
return dataSetIds;
}
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
public Set<Long> getViewIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet());

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import com.google.common.collect.Lists;
import lombok.AllArgsConstructor;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import java.util.HashMap;
import java.util.Map;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.AllArgsConstructor;
@@ -12,6 +12,6 @@ import java.util.List;
@AllArgsConstructor
public class NL2SQLTool extends AgentTool {
protected List<Long> dataSetIds;
protected List<Long> viewIds;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.agent;
package com.tencent.supersonic.chat.core.agent;
import lombok.Data;
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
private List<String> queryTypes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;

View File

@@ -1,13 +1,14 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import java.io.FileNotFoundException;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import java.io.FileNotFoundException;
@Data
@Configuration

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;

View File

@@ -1,12 +1,11 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.List;
@Data
@AllArgsConstructor
@ToString

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.headless.core.config;
import lombok.Data;
package com.tencent.supersonic.chat.core.config;
import java.util.List;
import lombok.Data;
/**
* when query an entity, return related dimension/metric info

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class EntityInternalDetail {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.core.config;
package com.tencent.supersonic.chat.core.config;
import lombok.Data;
@@ -9,21 +9,19 @@ import org.springframework.context.annotation.Configuration;
@Data
public class LLMParserConfig {
@Value("${s2.parser.url:}")
@Value("${llm.parser.url:}")
private String url;
@Value("${s2.query2sql.path:/query2sql}")
@Value("${query2sql.path:/query2sql}")
private String queryToSqlPath;
@Value("${s2.dimension.topn:10}")
@Value("${dimension.topn:10}")
private Integer dimensionTopN;
@Value("${s2.metric.topn:10}")
@Value("${metric.topn:10}")
private Integer metricTopN;
@Value("${s2.tag.topn:20}")
private Integer tagTopN;
@Value("${s2.all.model:false}")
@Value("${all.model:false}")
private Boolean allModel;
}

View File

@@ -0,0 +1,175 @@
package com.tencent.supersonic.chat.core.config;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.common.service.SysParameterService;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
@Slf4j
public class OptimizationConfig {
@Value("${one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@Value("${metric.dimension.min.threshold:0.3}")
private Double metricDimensionMinThresholdConfig;
@Value("${metric.dimension.threshold:0.3}")
private Double metricDimensionThresholdConfig;
@Value("${dimension.value.threshold:0.5}")
private Double dimensionValueThresholdConfig;
@Value("${long.text.threshold:0.8}")
private Double longTextThreshold;
@Value("${short.text.threshold:0.5}")
private Double shortTextThreshold;
@Value("${query.text.length.threshold:10}")
private Integer queryTextLengthThreshold;
@Value("${embedding.mapper.word.min:4}")
private int embeddingMapperWordMin;
@Value("${embedding.mapper.word.max:5}")
private int embeddingMapperWordMax;
@Value("${embedding.mapper.batch:50}")
private int embeddingMapperBatch;
@Value("${embedding.mapper.number:5}")
private int embeddingMapperNumber;
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.01}")
private Double embeddingMapperDistanceThreshold;
@Value("${s2SQL.linking.value.switch:true}")
private boolean useLinkingValueSwitch;
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
private SqlGenerationMode sqlGenerationMode;
@Value("${s2SQL.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${text2sql.example.num:15}")
private int text2sqlExampleNum;
@Value("${text2sql.fewShots.num:10}")
private int text2sqlFewShotsNum;
@Value("${text2sql.self.consistency.num:5}")
private int text2sqlSelfConsistencyNum;
@Value("${parse.show.count:3}")
private Integer parseShowCount;
@Autowired
private SysParameterService sysParameterService;
public Integer getOneDetectionSize() {
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
}
public Integer getOneDetectionMaxSize() {
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
}
public Double getMetricDimensionMinThresholdConfig() {
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
}
public Double getMetricDimensionThresholdConfig() {
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
}
public Double getDimensionValueThresholdConfig() {
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
}
public Double getLongTextThreshold() {
return convertValue("long.text.threshold", Double.class, longTextThreshold);
}
public Double getShortTextThreshold() {
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
}
public Integer getQueryTextLengthThreshold() {
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
}
public Integer getEmbeddingMapperWordMin() {
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
}
public Integer getEmbeddingMapperWordMax() {
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
}
public Integer getEmbeddingMapperBatch() {
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
}
public Integer getEmbeddingMapperNumber() {
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
}
public Integer getEmbeddingMapperRoundNumber() {
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
}
public Double getEmbeddingMapperDistanceThreshold() {
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
}
public boolean isUseLinkingValueSwitch() {
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
}
public SqlGenerationMode getSqlGenerationMode() {
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
}
public Integer getParseShowCount() {
return convertValue("parse.show.count", Integer.class, parseShowCount);
}
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
try {
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
if (StringUtils.isBlank(value)) {
return defaultValue;
}
if (targetType == Double.class) {
return targetType.cast(Double.parseDouble(value));
} else if (targetType == Integer.class) {
return targetType.cast(Integer.parseInt(value));
} else if (targetType == Boolean.class) {
return targetType.cast(Boolean.parseBoolean(value));
} else if (targetType == SqlGenerationMode.class) {
return targetType.cast(SqlGenerationMode.valueOf(value));
}
} catch (Exception e) {
log.error("convertValue", e);
}
return defaultValue;
}
}

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -37,7 +37,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return;
}
doCorrect(queryContext, semanticParseInfo);
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
} catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
}
@@ -45,7 +45,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
@@ -55,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
// support fieldName and field alias
Map<String, String> result = dbAllFields.stream()
.filter(entry -> dataSetId.equals(entry.getDataSet()))
.filter(entry -> viewId.equals(entry.getView()))
.flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
@@ -82,7 +82,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
//decide whether add order by expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
@@ -109,8 +109,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);
Long viewId = semanticParseInfo.getView().getView();
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
Map<String, String> metricToAggregate = metrics.stream()
.map(schemaElement -> {
@@ -135,24 +135,9 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
return semanticSchema.getMetrics(dataSetId);
return semanticSchema.getMetrics(viewId);
}
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
}
).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
return dimensions;
}
}

View File

@@ -1,18 +1,25 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -32,7 +39,23 @@ public class GroupByCorrector extends BaseSemanticCorrector {
}
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
Long viewId = semanticParseInfo.getViewId();
ViewService viewService = ContextUtils.getBean(ViewService.class);
ModelService modelService = ContextUtils.getBean(ModelService.class);
ViewResp viewResp = viewService.getView(viewId);
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
.collect(Collectors.toList());
MetaFilter metaFilter = new MetaFilter();
metaFilter.setIds(modelIds);
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
for (ModelResp modelResp : modelRespList) {
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
for (Dim dim : dimList) {
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
return false;
}
}
}
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
@@ -43,7 +66,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return false;
}
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
Set<String> dimensions = getDimensions(viewId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return false;
@@ -56,22 +79,33 @@ public class GroupByCorrector extends BaseSemanticCorrector {
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
return false;
}
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && !Boolean.parseBoolean(correctorAdditionalInfo)) {
return false;
}
return true;
}
private Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
}
).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
return dimensions;
}
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
Long viewId = semanticParseInfo.getViewId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
Set<String> dimensions = getDimensions(viewId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields = selectFields.stream()

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.apache.commons.lang3.StringUtils;
@@ -31,7 +31,7 @@ public class HavingCorrector extends BaseSemanticCorrector {
//decide whether add having expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
addHavingToSelect(semanticParseInfo);
}
@@ -39,11 +39,11 @@ public class HavingCorrector extends BaseSemanticCorrector {
}
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
Long viewId = semanticParseInfo.getView().getView();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {

View File

@@ -1,26 +1,17 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -61,7 +52,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
}
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
@@ -114,35 +105,4 @@ public class SchemaCorrector extends BaseSemanticCorrector {
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
}
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
if (CollectionUtils.isEmpty(whereExpressionList)) {
return;
}
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
if (CollectionUtils.isEmpty(linkingValues)) {
linkingValues = new ArrayList<>();
}
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
.collect(Collectors.toSet());
Set<String> removeFieldNames = whereExpressionList.stream()
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
sqlInfo.setCorrectS2SQL(sql);
}
}

View File

@@ -1,15 +1,12 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
/**
* Perform SQL corrections on the "Select" section in S2SQL.
*/
@@ -28,7 +25,5 @@ public class SelectCorrector extends BaseSemanticCorrector {
return;
}
addFieldsToSelect(semanticParseInfo, correctS2SQL);
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
}
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
/**
* A semantic corrector checks validity of extracted semantic information and

View File

@@ -1,25 +1,16 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
/**
* Perform SQL corrections on the time in S2SQL.
@@ -30,39 +21,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(queryContext, semanticParseInfo);
parserDateDiffFunction(semanticParseInfo);
addLowerBoundDate(semanticParseInfo);
}
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
startEndDate.getLeft(), dateChName, startEndDate.getRight());
try {
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
} catch (JSQLParserException e) {
log.error("parseCondExpression:{}", e);
}
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);

View File

@@ -1,21 +1,24 @@
package com.tencent.supersonic.headless.core.chat.corrector;
package com.tencent.supersonic.chat.core.corrector;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
@@ -34,6 +37,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(queryContext, semanticParseInfo);
addQueryFilter(queryContext, semanticParseInfo);
updateFieldValueByTechName(queryContext, semanticParseInfo);
@@ -57,6 +62,29 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
}
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String dateChName = TimeDimensionEnum.DAY.getChName();
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
startEndDate.getLeft(), dateChName, startEndDate.getRight());
try {
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
} catch (JSQLParserException e) {
log.error("parseCondExpression:{}", e);
}
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
@@ -73,8 +101,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Long dataSetId = semanticParseInfo.getDataSetId();
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
Long viewId = semanticParseInfo.getViewId();
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
if (CollectionUtils.isEmpty(dimensions)) {
return;

View File

@@ -0,0 +1,99 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
@Slf4j
public abstract class BaseMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
try {
doMap(queryContext);
} catch (Exception e) {
log.error("work error", e);
}
long cost = System.currentTimeMillis() - startTime;
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
}
public abstract void doMap(QueryContext queryContext);
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
if (schemaElementMatches == null) {
schemaElementMatches = modelElementMatches.get(modelId);
}
//remove duplication
AtomicBoolean needAddNew = new AtomicBoolean(true);
schemaElementMatches.removeIf(
existElementMatch -> {
SchemaElement existElement = existElementMatch.getElement();
SchemaElement newElement = newElementMatch.getElement();
if (existElement.equals(newElement)) {
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
return true;
} else {
needAddNew.set(false);
}
}
return false;
}
);
if (needAddNew.get()) {
schemaElementMatches.add(newElementMatch);
}
}
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
if (Objects.isNull(viewSchema)) {
return null;
}
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
if (Objects.isNull(elementDb)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
return null;
}
BeanUtils.copyProperties(elementDb, element);
element.setAlias(getAlias(elementDb));
return element;
}
public List<String> getAlias(SchemaElement element) {
if (!SchemaElementType.VALUE.equals(element.getType())) {
return element.getAlias();
}
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
element.getName())) {
return element.getAlias().stream()
.filter(aliasItem -> aliasItem.contains(element.getName()))
.collect(Collectors.toList());
}
return element.getAlias();
}
}

View File

@@ -1,17 +1,8 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.MapperConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -22,35 +13,37 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Autowired
protected MapperHelper mapperHelper;
@Autowired
protected MapperConfig mapperConfig;
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectViewIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
List<T> detects = detect(queryContext, terms, detectDataSetIds);
List<T> detects = detect(queryContext, terms, detectViewIds);
Map<MatchText, List<T>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryContext.getQueryText();
Set<T> results = new HashSet<>();
@@ -65,17 +58,18 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
detectByBatch(queryContext, results, detectViewIds, detectSegments);
return new ArrayList<>(results);
}
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
Set<String> detectSegments) {
return;
}
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
@@ -110,9 +104,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
}
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
Set<Long> dataSetIds = queryContext.getDataSetIds();
terms = filterByDataSetId(terms, dataSetIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
terms = filterByViewId(terms, viewIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
List<T> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;
@@ -127,17 +121,17 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
return matches;
}
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
logTerms(terms);
if (CollectionUtils.isNotEmpty(dataSetIds)) {
if (CollectionUtils.isNotEmpty(viewIds)) {
terms = terms.stream().filter(term -> {
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
if (Objects.nonNull(dataSetId)) {
return dataSetIds.contains(dataSetId);
Long viewId = NatureHelper.getViewId(term.getNature().toString());
if (Objects.nonNull(viewId)) {
return viewIds.contains(viewId);
}
return false;
}).collect(Collectors.toList());
log.info("terms filter by dataSetId:{}", dataSetIds);
log.info("terms filter by viewId:{}", viewIds);
logTerms(terms);
}
return terms;
@@ -156,12 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
String detectSegment, int offset);
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
double decreaseAmount = (threshold - minThreshold) / 4;
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
return divideThreshold >= minThreshold ? divideThreshold : minThreshold;
}
}

View File

@@ -1,14 +1,15 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -20,9 +21,6 @@ import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
/**
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
* It currently supports fuzzy matching against names and aliases.
@@ -31,13 +29,17 @@ import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NA
@Slf4j
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MapperHelper mapperHelper;
private List<SchemaElement> allElements;
@Override
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Set<Long> detectViewIds) {
this.allElements = getSchemaElements(queryContext);
return super.match(queryContext, terms, detectDataSetIds);
return super.match(queryContext, terms, detectViewIds);
}
@Override
@@ -52,7 +54,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
+ Constants.UNDERLINE + a.getSchemaElement().getName();
}
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return;
@@ -68,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
if (!CollectionUtils.isEmpty(detectViewIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {
@@ -91,19 +93,22 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
}
private Double getThreshold(QueryContext queryContext) {
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
threshold = threshold / 2;
log.info("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
modelElementMatches, threshold);
double halfThreshold = metricDimensionThresholdConfig / 2;
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
modelElementMatches, metricDimensionThresholdConfig);
}
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {

View File

@@ -1,18 +1,17 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Objects;
@@ -26,7 +25,8 @@ public class EmbeddingMapper extends BaseMapper {
public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText
String queryText = queryContext.getQueryText();
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
List<S2Term> terms = knowledgeService.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
@@ -36,12 +36,12 @@ public class EmbeddingMapper extends BaseMapper {
//2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
if (Objects.isNull(dataSetId)) {
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
if (Objects.isNull(viewId)) {
continue;
}
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;
@@ -54,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
.detectWord(matchResult.getDetectWord())
.build();
//3. add to mapInfo
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
}
}
}

View File

@@ -1,13 +1,20 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -15,20 +22,6 @@ import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_BATCH;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MAX;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MIN;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_NUMBER;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
/**
* EmbeddingMatchStrategy uses vector database to perform
* similarity search against the embeddings of schema elements.
@@ -37,6 +30,9 @@ import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING
@Slf4j
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MetaEmbeddingService metaEmbeddingService;
@@ -52,47 +48,39 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
Set<Long> detectDataSetIds, String detectSegment, int offset) {
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MIN));
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MAX));
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_BATCH));
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
Set<String> detectSegments) {
List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
&& detectSegment.length() >= embedddingMapperMin
&& detectSegment.length() <= embedddingMapperMax)
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
.collect(Collectors.toList());
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
embeddingMapperBatch);
optimizationConfig.getEmbeddingMapperBatch());
for (List<String> queryTextsSub : queryTextsSubList) {
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
}
}
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
List<String> queryTextsSub, QueryContext queryContext) {
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
List<String> queryTextsSub) {
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
@@ -102,12 +90,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
.map(retrieveQueryResult -> {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
if (CollectionUtils.isNotEmpty(retrievals)) {
retrievals.removeIf(retrieval -> {
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
return retrieval.getDistance() > 1 - threshold;
}
return false;
});
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
}
return retrieveQueryResult;
})
@@ -126,8 +109,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
.collect(Collectors.toList());
// step4. select mapResul in one round
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
List<EmbeddingResult> oneRoundResults = collect.stream()
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
.limit(roundNumber)

View File

@@ -1,12 +1,13 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@@ -23,18 +24,18 @@ public class EntityMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(dataSetId, queryContext);
SchemaElement entity = getEntity(viewId, queryContext);
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch -> SchemaElementType.VALUE.equals(
schemaElementMatch.getElement().getType()))
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
@@ -64,9 +65,9 @@ public class EntityMapper extends BaseMapper {
return false;
}
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}

View File

@@ -0,0 +1,119 @@
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
* match schema elements. It currently supports prefix and suffix matching
* against names, values and aliases.
*/
@Service
@Slf4j
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Autowired
private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private KnowledgeService knowledgeService;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
@Override
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);
if (CollectionUtils.isEmpty(hanlpMapResults)) {
return;
}
// step3. merge pre/suffix result
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by similarity
hanlpMapResults = hanlpMapResults.stream()
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
>= mapperHelper.getThresholdMatch(term.getNatures()))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
parseResult.setOffset(offset);
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only one dimension or 10 metric/dimension value per rond.
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.collect(Collectors.toList())
.stream()
.limit(1)
.collect(Collectors.toList());
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
oneRoundResults = dimensionMetrics;
}
// step6. select mapResul in one round
selectResultInOneRound(existResults, oneRoundResults);
}
public String getMapKey(HanlpMapResult a) {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
}

View File

@@ -1,17 +1,17 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -33,7 +33,8 @@ public class KeywordMapper extends BaseMapper {
public void doMap(QueryContext queryContext) {
String queryText = queryContext.getQueryText();
//1.hanlpDict Match
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
List<S2Term> terms = knowledgeService.getTerms(queryText);
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
@@ -58,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
for (HanlpMapResult hanlpMapResult : mapResults) {
for (String nature : hanlpMapResult.getNatures()) {
Long dataSetId = NatureHelper.getDataSetId(nature);
if (Objects.isNull(dataSetId)) {
Long viewId = NatureHelper.getViewId(nature);
if (Objects.isNull(viewId)) {
continue;
}
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
@@ -67,11 +68,14 @@ public class KeywordMapper extends BaseMapper {
continue;
}
Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(dataSetId, elementType,
SchemaElement element = getSchemaElement(viewId, elementType,
elementID, queryContext.getSemanticSchema());
if (element == null) {
continue;
}
if (element.getType().equals(SchemaElementType.VALUE)) {
element.setName(hanlpMapResult.getName());
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
@@ -81,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
.detectWord(hanlpMapResult.getDetectWord())
.build();
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
}
}
}
@@ -98,16 +102,16 @@ public class KeywordMapper extends BaseMapper {
.element(schemaElement)
.word(schemaElement.getName())
.detectWord(match.getDetectWord())
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.frequency(10000L)
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
}
}
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
if (CollectionUtils.isEmpty(elements)) {
return new HashSet<>();
}

View File

@@ -1,16 +1,21 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Data
@@ -18,6 +23,9 @@ import java.util.stream.Collectors;
@Slf4j
public class MapperHelper {
@Autowired
private OptimizationConfig optimizationConfig;
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
Integer subRegLength = regOffsetToLength.get(index);
if (Objects.nonNull(subRegLength)) {
@@ -40,6 +48,13 @@ public class MapperHelper {
return index;
}
public double getThresholdMatch(List<String> natures) {
if (existDimensionValues(natures)) {
return optimizationConfig.getDimensionValueThresholdConfig();
}
return optimizationConfig.getMetricDimensionThresholdConfig();
}
/***
* exist dimension values
* @param natures
@@ -47,16 +62,7 @@ public class MapperHelper {
*/
public boolean existDimensionValues(List<String> natures) {
for (String nature : natures) {
if (NatureHelper.isDimensionValueDataSetId(nature)) {
return true;
}
}
return false;
}
public boolean existTerms(List<String> natures) {
for (String nature : natures) {
if (NatureHelper.isTermNature(nature)) {
if (NatureHelper.isDimensionValueViewId(nature)) {
return true;
}
}
@@ -75,4 +81,34 @@ public class MapperHelper {
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
detectSegment.length());
}
public Set<Long> getViewIds(Long viewId, Agent agent) {
Set<Long> detectViewIds = new HashSet<>();
if (Objects.nonNull(agent)) {
detectViewIds = agent.getViewIds(null);
}
//contains all
if (Agent.containsAllModel(detectViewIds)) {
if (Objects.nonNull(viewId) && viewId > 0) {
Set<Long> result = new HashSet<>();
result.add(viewId);
return result;
}
return new HashSet<>();
}
if (Objects.nonNull(detectViewIds)) {
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
}
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
if (detectViewIds.contains(viewId)) {
Set<Long> result = new HashSet<>();
result.add(viewId);
return result;
}
}
return detectViewIds;
}
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import java.util.List;
import java.util.Map;
@@ -14,6 +13,6 @@ import java.util.Set;
*/
public interface MatchStrategy<T> {
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
}

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import java.util.Objects;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
import java.util.Objects;
@Data
@ToString
@Builder

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import java.io.Serializable;
import lombok.Data;
import lombok.ToString;
import java.io.Serializable;
@Data
@ToString
public class ModelWithSemanticType implements Serializable {

View File

@@ -1,59 +1,56 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
public class QueryFilterMapper extends BaseMapper {
public class QueryFilterMapper implements SchemaMapper {
private double similarity = 1.0;
@Override
public void doMap(QueryContext queryContext) {
Set<Long> dataSetIds = queryContext.getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
public void map(QueryContext queryContext) {
Long viewId = queryContext.getViewId();
if (viewId == null || viewId <= 0) {
return;
}
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
for (Long dataSetId : dataSetIds) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId);
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches);
}
addValueSchemaElementMatch(dataSetId, queryContext, schemaElementMatches);
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
}
addValueSchemaElementMatch(queryContext, schemaElementMatches);
}
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
if (!viewIds.contains(entry.getKey())) {
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
if (!entry.getKey().equals(modelId)) {
entry.getValue().clear();
}
}
}
private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
private List<SchemaElementMatch> addValueSchemaElementMatch(QueryContext queryContext,
List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = queryContext.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return;
return candidateElementMatches;
}
for (QueryFilter filter : queryFilters.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
@@ -64,7 +61,7 @@ public class QueryFilterMapper extends BaseMapper {
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.dataSet(dataSetId)
.view(queryContext.getViewId())
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
@@ -75,7 +72,7 @@ public class QueryFilterMapper extends BaseMapper {
.build();
candidateElementMatches.add(schemaElementMatch);
}
queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
return candidateElementMatches;
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
/**
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)

View File

@@ -1,23 +1,22 @@
package com.tencent.supersonic.headless.core.chat.mapper;
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.server.service.KnowledgeService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
* SearchMatchStrategy encapsulates a concrete matching algorithm
@@ -29,11 +28,11 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
private static final int SEARCH_SIZE = 3;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
private KnowledgeService knowledgeService;
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
Set<Long> detectDataSetIds) {
Set<Long> detectViewIds) {
String text = queryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
@@ -57,10 +56,10 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
String detectSegment = text.substring(detectIndex);
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, detectViewIds);
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
detectSegment, SEARCH_SIZE, detectViewIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity name where search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
@@ -94,7 +93,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}

View File

@@ -0,0 +1,66 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.util.Objects;
/**
* LLMProxy based on langchain4j Java version.
*/
@Slf4j
@Component
public class JavaLLMProxy implements LLMProxy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Override
public boolean isSkip(QueryContext queryContext) {
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
if (Objects.isNull(chatLanguageModel)) {
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
return true;
}
return false;
}
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getViewName();
LLMResp result = sqlGeneration.generation(llmReq, viewId);
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
return result;
}
@Override
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
functionReq.getPluginConfigs());
keyPipelineLog.info("functionCallPrompt:{}", functionCallPrompt);
String response = chatLanguageModel.generate(functionCallPrompt);
keyPipelineLog.info("functionCall response:{}", response);
return OutputFormat.functionCallParse(response);
}
}

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
/**
* LLMProxy encapsulates functions performed by LLMs so that multiple
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
* could be used.
*/
public interface LLMProxy {
boolean isSkip(QueryContext queryContext);
LLMResp query2sql(LLMReq llmReq, Long viewId);
FunctionResp requestFunction(FunctionReq functionReq);
}

View File

@@ -0,0 +1,104 @@
package com.tencent.supersonic.chat.core.parser;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
/**
* PythonLLMProxy sends requests to LangChain-based python service.
*/
@Slf4j
@Component
public class PythonLLMProxy implements LLMProxy {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Override
public boolean isSkip(QueryContext queryContext) {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName());
return true;
}
return false;
}
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class);
LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", llmResp);
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
}
return llmResp;
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
HttpHeaders headers = new HttpHeaders();
long startTime = System.currentTimeMillis();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
keyPipelineLog.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
FunctionResp.class);
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
System.currentTimeMillis() - startTime);
keyPipelineLog.info("response:{}", responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestFunction error", e);
}
return null;
}
}

View File

@@ -1,18 +1,18 @@
package com.tencent.supersonic.headless.core.chat.parser;
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -50,54 +50,38 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.ID;
}
//1. entity queryType
Long dataSetId = parseInfo.getDataSetId();
Long viewId = parseInfo.getViewId();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
//If all the fields in the SELECT statement are of tag type.
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(whereFields)) {
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
.anyMatch(whereFilterByTimeFields::contains)) {
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
return QueryType.ID;
}
}
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
selectFields.addAll(whereFields);
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
//If all the fields in the SELECT/WHERE statement are of tag type.
if (CollectionUtils.isNotEmpty(tags)
&& tags.containsAll(selectWhereFilterByTimeFields)) {
return QueryType.DETAIL;
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
return QueryType.TAG;
}
}
}
//2. metric queryType
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
if (containMetric) {
return QueryType.METRIC;
}
}
return QueryType.ID;
}
private static List<String> filterByTimeFields(List<String> whereFields) {
List<String> selectAndWhereFilterByTimeFields = whereFields
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
return selectAndWhereFilterByTimeFields;
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
return selectFields.stream().anyMatch(metricNameSet::contains);
}
return false;
}
}

View File

@@ -0,0 +1,49 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
/**
* This checker can be used by semantic parsers to check if query intent
* has already been satisfied by current candidate queries. If so, current
* parser could be skipped.
*/
@Slf4j
public class SatisfactionChecker {
// check all the parse info in candidate
public static boolean isSkip(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
return true;
}
}
return false;
}
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength;
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
if (degree < optimizationConfig.getLongTextThreshold()) {
return false;
}
} else if (degree < optimizationConfig.getShortTextThreshold()) {
return false;
}
log.info("queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
return true;
}
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
/**
* A semantic parser understands user queries and extracts semantic information.
* It could leverage either rule-based or LLM-based approach to identify query intent
* and extract related semantic items from the query.
*/
public interface SemanticParser {
void parse(QueryContext queryContext, ChatContext chatContext);
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.plugin;
package com.tencent.supersonic.chat.core.parser.plugin;
public enum ParseMode {

View File

@@ -0,0 +1,123 @@
package com.tencent.supersonic.chat.core.parser.plugin;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* PluginParser defines the basic process and common methods for recalling plugins.
*/
public abstract class PluginParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
return;
}
}
if (!checkPreCondition(queryContext)) {
return;
}
PluginRecallResult pluginRecallResult = recallPlugin(queryContext);
if (pluginRecallResult == null) {
return;
}
buildQuery(queryContext, pluginRecallResult);
}
public abstract boolean checkPreCondition(QueryContext queryContext);
public abstract PluginRecallResult recallPlugin(QueryContext queryContext);
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
Plugin plugin = pluginRecallResult.getPlugin();
Set<Long> viewIds = pluginRecallResult.getViewIds();
if (plugin.isContainsAllModel()) {
viewIds = Sets.newHashSet(-1L);
}
for (Long viewId : viewIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
queryContext, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
}
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext);
}
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
QueryContext queryContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
QueryFilters queryFilters = queryContext.getQueryFilters();
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
viewId = plugin.getViewList().get(0);
}
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance);
pluginParseResult.setQueryText(queryContext.getQueryText());
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
return semanticParseInfo;
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return;
}
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}

View File

@@ -1,18 +1,18 @@
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.plugin.ParseMode;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.chat.parser.llm.PythonLLMProxy;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -28,42 +28,44 @@ import java.util.stream.Collectors;
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
*/
@Slf4j
public class EmbeddingRecallRecognizer extends PluginRecognizer {
public class EmbeddingRecallParser extends PluginParser {
public boolean checkPreCondition(ChatParseContext chatParseContext) {
@Override
public boolean checkPreCondition(QueryContext queryContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
return false;
}
List<Plugin> plugins = getPluginList(chatParseContext);
List<Plugin> plugins = getPluginList(queryContext);
return !CollectionUtils.isEmpty(plugins);
}
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
String text = chatParseContext.getQueryText();
@Override
public PluginRecallResult recallPlugin(QueryContext queryContext) {
String text = queryContext.getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null;
}
List<Plugin> plugins = getPluginList(chatParseContext);
List<Plugin> plugins = getPluginList(queryContext);
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
continue;
}
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
Set<Long> dataSetList = pair.getRight();
if (CollectionUtils.isEmpty(dataSetList)) {
Set<Long> viewList = pair.getRight();
if (CollectionUtils.isEmpty(viewList)) {
continue;
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = chatParseContext.getQueryText().length() * (1 - distance);
double score = queryContext.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
}
}
return null;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
import lombok.Data;

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
import lombok.Data;
import java.util.List;
@Data

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.core.parser.plugin.function;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class FunctionCallConfig {
@Value("${functionCall.url:}")
private String url;
@Value("${funtionCall.plugin.select.path:/plugin_selection}")
private String pluginSelectPath;
}

Some files were not shown because too many files have changed in this diff Show More