mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
Compare commits
176 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f766f84b9 | ||
|
|
5229fdc8b5 | ||
|
|
bef652892b | ||
|
|
d2306464a6 | ||
|
|
371e2f1e05 | ||
|
|
59c50176c5 | ||
|
|
be9a8bbc27 | ||
|
|
afa82bf98d | ||
|
|
7e013ca36a | ||
|
|
ca098b576c | ||
|
|
af270580bf | ||
|
|
4632549603 | ||
|
|
bd2eaef3f6 | ||
|
|
ba55ecb31e | ||
|
|
10a5e485cb | ||
|
|
2801b27ade | ||
|
|
07e0ba24bc | ||
|
|
115cf19078 | ||
|
|
898c7100ba | ||
|
|
7150f19def | ||
|
|
6aff51d394 | ||
|
|
8d29e89317 | ||
|
|
93fedd787f | ||
|
|
ebea58098c | ||
|
|
95be7f3ce1 | ||
|
|
d32d791238 | ||
|
|
4d65b8b93a | ||
|
|
4d02bb7068 | ||
|
|
c82c2d0a95 | ||
|
|
0c70df12ca | ||
|
|
1ff4a71a41 | ||
|
|
8b01dac8d4 | ||
|
|
2f2f493d17 | ||
|
|
b13b38c645 | ||
|
|
68952fdb55 | ||
|
|
ba9e6afa51 | ||
|
|
da0ac7b26c | ||
|
|
db698ecb75 | ||
|
|
cc3fa0078a | ||
|
|
e586d887ed | ||
|
|
ecc651e12d | ||
|
|
24c63c93bb | ||
|
|
e88e381302 | ||
|
|
f06cd0b296 | ||
|
|
1608317ab3 | ||
|
|
cdb67650c5 | ||
|
|
3ca51145e5 | ||
|
|
794a448619 | ||
|
|
9dbc8657e2 | ||
|
|
b8aeff9a6a | ||
|
|
82f86a8635 | ||
|
|
3d1ca6ac1d | ||
|
|
208686de46 | ||
|
|
c8fe6d2d04 | ||
|
|
45fb83356f | ||
|
|
5ad8ac69ab | ||
|
|
3621766a0d | ||
|
|
89b028b594 | ||
|
|
0a4272c25e | ||
|
|
e2e45a40ab | ||
|
|
97bf8049d7 | ||
|
|
a9232fa1c7 | ||
|
|
ac6b28ebb7 | ||
|
|
53a9f7c451 | ||
|
|
e26263d229 | ||
|
|
cabcbf16ff | ||
|
|
5a18ad5229 | ||
|
|
ac96f72b07 | ||
|
|
ddd44f343a | ||
|
|
55abc883ab | ||
|
|
27a70de1be | ||
|
|
4a5bb9e457 | ||
|
|
12a504585f | ||
|
|
b5fa54a754 | ||
|
|
52a584b9a4 | ||
|
|
b45f3d0663 | ||
|
|
e571ca1f55 | ||
|
|
9a1fac5d4c | ||
|
|
23af977972 | ||
|
|
2472ce2461 | ||
|
|
9e4513f7ca | ||
|
|
9a14728152 | ||
|
|
26f682cc45 | ||
|
|
ccd79e4830 | ||
|
|
e5504473a4 | ||
|
|
ebbb519c07 | ||
|
|
8f620480c6 | ||
|
|
85088abd7b | ||
|
|
f38a84bc8c | ||
|
|
8307c813d2 | ||
|
|
cd8f38c334 | ||
|
|
ae34c15c95 | ||
|
|
c8df102402 | ||
|
|
335902bd1f | ||
|
|
0f5b49f7c5 | ||
|
|
f0bdb14818 | ||
|
|
c39460ee02 | ||
|
|
865788b71b | ||
|
|
8f55fa0c1e | ||
|
|
f03e5a0d38 | ||
|
|
73d9cbdbc1 | ||
|
|
d64ed02df9 | ||
|
|
3797cc2ce8 | ||
|
|
7d64aa893c | ||
|
|
b5768b27aa | ||
|
|
b319f4682a | ||
|
|
ef44954e1e | ||
|
|
17f965eabd | ||
|
|
2425067091 | ||
|
|
2eac301076 | ||
|
|
f30c74c18f | ||
|
|
2cec2e61bc | ||
|
|
1b37925bf3 | ||
|
|
ff38a6d250 | ||
|
|
782b768950 | ||
|
|
1c0b8f8161 | ||
|
|
35892f2344 | ||
|
|
13ae312e51 | ||
|
|
4d1360b924 | ||
|
|
529251097b | ||
|
|
d5c78d87e7 | ||
|
|
4eb6193699 | ||
|
|
407c8d4702 | ||
|
|
baff30550e | ||
|
|
e365a36749 | ||
|
|
5bf4a4160d | ||
|
|
37da1ac2ae | ||
|
|
41ad1ada6c | ||
|
|
f9d6ea11c5 | ||
|
|
d6c5702b5a | ||
|
|
e0647dd990 | ||
|
|
ee258d4c8f | ||
|
|
d7c1b2cbaa | ||
|
|
e041fcb37e | ||
|
|
9bb95ca4be | ||
|
|
76fb3cb4a2 | ||
|
|
78a91ad8c2 | ||
|
|
03f5678732 | ||
|
|
62e70f5cb7 | ||
|
|
8ef5ce8b76 | ||
|
|
d3f3fc5de3 | ||
|
|
c3b3b7e769 | ||
|
|
ea4aa3eacf | ||
|
|
f0b4eb46cf | ||
|
|
7a376bd9a3 | ||
|
|
c9c049a20f | ||
|
|
efd617b2e5 | ||
|
|
9911e6772c | ||
|
|
08ae27ab43 | ||
|
|
64786cb0ef | ||
|
|
4d7bfe07aa | ||
|
|
3f460429e6 | ||
|
|
d849acf971 | ||
|
|
e0e77a3b64 | ||
|
|
609146d12c | ||
|
|
6db6aaf98d | ||
|
|
d39db734c4 | ||
|
|
a1ab7ac1c1 | ||
|
|
16c3ff0c30 | ||
|
|
7c86e2b3db | ||
|
|
2ddf0ad41a | ||
|
|
66e5ee06e6 | ||
|
|
72465cd88c | ||
|
|
097f2f4fe7 | ||
|
|
71954e42a8 | ||
|
|
14b9086d83 | ||
|
|
93ea7a618c | ||
|
|
bb4cd880b0 | ||
|
|
fa5abc58a5 | ||
|
|
d200086779 | ||
|
|
5d5ca438a6 | ||
|
|
c5478ad8a2 | ||
|
|
20aa0bc0a9 | ||
|
|
9cd352f146 | ||
|
|
78e023e955 | ||
|
|
ccf41fa6db |
24
.github/workflows/centos-ci.yml
vendored
24
.github/workflows/centos-ci.yml
vendored
@@ -1,5 +1,4 @@
|
|||||||
name: supersonic RHEL/CentOS CI
|
name: supersonic CentOS CI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
@@ -14,31 +13,52 @@ jobs:
|
|||||||
container:
|
container:
|
||||||
image: quay.io/centos/centos:stream8 # 使用 CentOS Stream 8 容器
|
image: quay.io/centos/centos:stream8 # 使用 CentOS Stream 8 容器
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
java-version: [8, 11, 21] # 定义要测试的JDK版本
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Reset DNF repositories
|
- name: Reset DNF repositories
|
||||||
run: |
|
run: |
|
||||||
cd /etc/yum.repos.d/
|
cd /etc/yum.repos.d/
|
||||||
sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-*
|
sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-*
|
||||||
sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
|
sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
|
||||||
|
|
||||||
- name: Update DNF package index
|
- name: Update DNF package index
|
||||||
run: dnf makecache
|
run: dnf makecache
|
||||||
|
|
||||||
- name: Install Java and Maven with retry
|
- name: Install Java and Maven with retry
|
||||||
run: |
|
run: |
|
||||||
|
if [ ${{ matrix.java-version }} -eq 8 ]; then
|
||||||
for i in {1..5}; do
|
for i in {1..5}; do
|
||||||
dnf install -y java-1.8.0-openjdk-devel maven && break || sleep 15
|
dnf install -y java-1.8.0-openjdk-devel maven && break || sleep 15
|
||||||
done
|
done
|
||||||
|
elif [ ${{ matrix.java-version }} -eq 11 ]; then
|
||||||
|
for i in {1..5}; do
|
||||||
|
dnf install -y java-11-openjdk-devel maven && break || sleep 15
|
||||||
|
done
|
||||||
|
elif [ ${{ matrix.java-version }} -eq 21 ]; then
|
||||||
|
for i in {1..5}; do
|
||||||
|
dnf install -y java-21-openjdk-devel maven && break || sleep 15
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Verify Java and Maven installation
|
- name: Verify Java and Maven installation
|
||||||
run: |
|
run: |
|
||||||
java -version
|
java -version
|
||||||
mvn -version
|
mvn -version
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v2
|
||||||
with:
|
with:
|
||||||
path: ~/.m2
|
path: ~/.m2
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
restore-keys: ${{ runner.os }}-m2
|
restore-keys: ${{ runner.os }}-m2
|
||||||
|
|
||||||
- name: Build with Maven
|
- name: Build with Maven
|
||||||
run: mvn -B package --file pom.xml
|
run: mvn -B package --file pom.xml
|
||||||
|
|
||||||
- name: Test with Maven
|
- name: Test with Maven
|
||||||
run: mvn test
|
run: mvn test
|
||||||
8
.github/workflows/mac-ci.yml
vendored
8
.github/workflows/mac-ci.yml
vendored
@@ -12,13 +12,17 @@ jobs:
|
|||||||
build:
|
build:
|
||||||
runs-on: macos-latest # Specify a macOS runner
|
runs-on: macos-latest # Specify a macOS runner
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
java-version: [8, 11, 21] # Define the JDK versions to test
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Set up JDK 8
|
- name: Set up JDK ${{ matrix.java-version }}
|
||||||
uses: actions/setup-java@v2
|
uses: actions/setup-java@v2
|
||||||
with:
|
with:
|
||||||
java-version: '8'
|
java-version: ${{ matrix.java-version }}
|
||||||
distribution: 'adopt'
|
distribution: 'adopt'
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
|
|||||||
14
.github/workflows/ubuntu-ci.yml
vendored
14
.github/workflows/ubuntu-ci.yml
vendored
@@ -7,25 +7,33 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
java-version: [8, 11, 21] # 定义要测试的JDK版本
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
- name: Set up JDK 8
|
|
||||||
|
- name: Set up JDK ${{ matrix.java-version }}
|
||||||
uses: actions/setup-java@v2
|
uses: actions/setup-java@v2
|
||||||
with:
|
with:
|
||||||
java-version: '8'
|
java-version: ${{ matrix.java-version }}
|
||||||
distribution: 'adopt'
|
distribution: 'adopt'
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v2
|
||||||
with:
|
with:
|
||||||
path: ~/.m2
|
path: ~/.m2
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
restore-keys: ${{ runner.os }}-m2
|
restore-keys: ${{ runner.os }}-m2
|
||||||
|
|
||||||
- name: Build with Maven
|
- name: Build with Maven
|
||||||
run: mvn -B package --file pom.xml
|
run: mvn -B package --file pom.xml
|
||||||
|
|
||||||
- name: Test with Maven
|
- name: Test with Maven
|
||||||
run: mvn test
|
run: mvn test
|
||||||
10
.github/workflows/windows-ci.yml
vendored
10
.github/workflows/windows-ci.yml
vendored
@@ -12,14 +12,18 @@ jobs:
|
|||||||
build:
|
build:
|
||||||
runs-on: windows-latest # Specify a Windows runner
|
runs-on: windows-latest # Specify a Windows runner
|
||||||
|
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
java-version: [8, 11, 21] # Add JDK 21 to the matrix
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
- name: Set up JDK 8
|
- name: Set up JDK ${{ matrix.java-version }}
|
||||||
uses: actions/setup-java@v2
|
uses: actions/setup-java@v2
|
||||||
with:
|
with:
|
||||||
java-version: '8'
|
java-version: ${{ matrix.java-version }}
|
||||||
distribution: 'adopt'
|
distribution: 'adopt' # You might need to change this if 'adopt' doesn't support JDK 21
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v2
|
||||||
|
|||||||
10
README.md
10
README.md
@@ -60,6 +60,12 @@ The high-level architecture and main process flow is as follows:
|
|||||||
### Online playground
|
### Online playground
|
||||||
Visit http://117.72.46.148:9080 to register and experience as a new user. Please do not modify system configurations. We will restart to reset configurations regularly every weekend.
|
Visit http://117.72.46.148:9080 to register and experience as a new user. Please do not modify system configurations. We will restart to reset configurations regularly every weekend.
|
||||||
|
|
||||||
|
### Docker Deployment
|
||||||
|
- Install Docker and docker-compose.
|
||||||
|
- Download the docker-compose.yml file; Execute: wget https://raw.githubusercontent.com/tencentmusic/supersonic/master/docker/docker-compose.yml.
|
||||||
|
- Execute "docker-compose up -d".
|
||||||
|
- Open a browser and visit http://localhost:9080 to start exploring.
|
||||||
|
|
||||||
### Local build
|
### Local build
|
||||||
SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:
|
SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:
|
||||||
|
|
||||||
@@ -76,7 +82,3 @@ Please refer to project [Docs](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%
|
|||||||
Please follow SuperSonic wechat official account:
|
Please follow SuperSonic wechat official account:
|
||||||
|
|
||||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
||||||
|
|
||||||
Welcome to join the WeChat community:
|
|
||||||
|
|
||||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat.png" height="50%" width="50%" />
|
|
||||||
|
|||||||
10
README_CN.md
10
README_CN.md
@@ -59,6 +59,12 @@ SuperSonic的整体架构和主流程如下图所示:
|
|||||||
### 线上环境体验
|
### 线上环境体验
|
||||||
访问http://117.72.46.148:9080 注册新用户体验. 请勿修改系统配置。我们每周末定期重启重置配置。
|
访问http://117.72.46.148:9080 注册新用户体验. 请勿修改系统配置。我们每周末定期重启重置配置。
|
||||||
|
|
||||||
|
### Docker部署
|
||||||
|
- 安装好Docker以及docker-compose
|
||||||
|
- 下载docker-compose.yml;执行命令:wget https://raw.githubusercontent.com/tencentmusic/supersonic/master/docker/docker-compose.yml
|
||||||
|
- 执行:"docker-compose up -d"
|
||||||
|
- 在浏览器访问http://localhost:9080 开启探索
|
||||||
|
|
||||||
### 本地构建
|
### 本地构建
|
||||||
|
|
||||||
SuperSonic自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
SuperSonic自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
||||||
@@ -76,7 +82,3 @@ SuperSonic自带样例的语义模型和问答对话,只需以下三步即可
|
|||||||
欢迎关注微信公众号:
|
欢迎关注微信公众号:
|
||||||
|
|
||||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
||||||
|
|
||||||
欢迎加入微信社群:
|
|
||||||
|
|
||||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat.png" height="50%" width="50%" />
|
|
||||||
|
|||||||
10
README_JP.md
10
README_JP.md
@@ -56,6 +56,12 @@ ChatGPTのような大規模言語モデル(LLM)の出現は、情報検索
|
|||||||
### オンラインプレイグラウンド
|
### オンラインプレイグラウンド
|
||||||
http://117.72.46.148:9080 にアクセスして、新規ユーザーとして登録して体験してください。システム設定を変更しないでください。毎週末に定期的に再起動して設定をリセットします。
|
http://117.72.46.148:9080 にアクセスして、新規ユーザーとして登録して体験してください。システム設定を変更しないでください。毎週末に定期的に再起動して設定をリセットします。
|
||||||
|
|
||||||
|
### Dockerのデプロイメント
|
||||||
|
- Dockerおよびdocker-composeをインストールします。
|
||||||
|
- docker-compose.ymlファイルをダウンロードします。コマンドを実行します:wget https://raw.githubusercontent.com/tencentmusic/supersonic/master/docker/docker-compose.yml。
|
||||||
|
- docker-compose up -dを実行します。
|
||||||
|
- ブラウザを開いてhttp://localhost:9080にアクセスし、探索を開始します。
|
||||||
|
|
||||||
### ローカルビルド
|
### ローカルビルド
|
||||||
SuperSonicには、サンプルのセマンティックモデルとチャット会話が付属しており、以下の手順で簡単に体験できます:
|
SuperSonicには、サンプルのセマンティックモデルとチャット会話が付属しており、以下の手順で簡単に体験できます:
|
||||||
|
|
||||||
@@ -72,7 +78,3 @@ SuperSonicには、サンプルのセマンティックモデルとチャット
|
|||||||
SuperSonicの公式WeChatアカウントをフォローしてください:
|
SuperSonicの公式WeChatアカウントをフォローしてください:
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
WeChatコミュニティに参加することを歓迎します:
|
|
||||||
|
|
||||||

|
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ if "%command%"=="restart" (
|
|||||||
set "webDir=%baseDir%\webapp"
|
set "webDir=%baseDir%\webapp"
|
||||||
set "logDir=%baseDir%\logs"
|
set "logDir=%baseDir%\logs"
|
||||||
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||||
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Xms1024m -Xmx1024m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||||
if not exist %logDir% mkdir %logDir%
|
if not exist %logDir% mkdir %logDir%
|
||||||
start /B java %java-command% >nul 2>&1
|
start /B java %java-command% >nul 2>&1
|
||||||
timeout /t 10 >nul
|
timeout /t 10 >nul
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ function runJavaService {
|
|||||||
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
||||||
fi
|
fi
|
||||||
export PATH=$JAVA_HOME/bin:$PATH
|
export PATH=$JAVA_HOME/bin:$PATH
|
||||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m $main_class"
|
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dapp_name=${local_app_name} -Xms1024m -Xmx1024m $main_class"
|
||||||
|
|
||||||
mkdir -p $javaRunDir/logs
|
mkdir -p $javaRunDir/logs
|
||||||
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.tencent.supersonic.auth.api.authorization.pojo;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class AuthResGrp {
|
|
||||||
|
|
||||||
private List<AuthRes> group = new ArrayList<>();
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.auth.api.authorization.request;
|
package com.tencent.supersonic.auth.api.authorization.request;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -15,8 +14,6 @@ public class QueryAuthResReq {
|
|||||||
|
|
||||||
private List<String> departmentIds = new ArrayList<>();
|
private List<String> departmentIds = new ArrayList<>();
|
||||||
|
|
||||||
private List<AuthRes> resources;
|
|
||||||
|
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
|
|
||||||
private List<Long> modelIds;
|
private List<Long> modelIds;
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
package com.tencent.supersonic.auth.api.authorization.response;
|
package com.tencent.supersonic.auth.api.authorization.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class AuthorizedResourceResp {
|
public class AuthorizedResourceResp {
|
||||||
|
|
||||||
private List<AuthResGrp> resources = new ArrayList<>();
|
private List<AuthRes> authResList = new ArrayList<>();
|
||||||
|
|
||||||
private List<DimensionFilter> filters = new ArrayList<>();
|
private List<DimensionFilter> filters = new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
package com.tencent.supersonic.auth.authorization.service;
|
package com.tencent.supersonic.auth.authorization.service;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
|
||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
|
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
||||||
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
||||||
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
||||||
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.jdbc.core.JdbcTemplate;
|
import org.springframework.jdbc.core.JdbcTemplate;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -79,66 +76,48 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||||
|
if (CollectionUtils.isEmpty(req.getModelIds())) {
|
||||||
|
return new AuthorizedResourceResp();
|
||||||
|
}
|
||||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||||
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
|
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
|
||||||
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
||||||
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||||
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
||||||
Map<Long, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
for (Long modelId : req.getModelIds()) {
|
||||||
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
|
||||||
|
|
||||||
for (Long modelId : reqAuthRes.keySet()) {
|
|
||||||
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
|
||||||
AuthResGrp rg = new AuthResGrp();
|
|
||||||
if (authGroupsByModelId.containsKey(modelId)) {
|
if (authGroupsByModelId.containsKey(modelId)) {
|
||||||
List<AuthGroup> authGroups = authGroupsByModelId.get(modelId);
|
List<AuthGroup> authGroups = authGroupsByModelId.get(modelId);
|
||||||
for (AuthRes reqRes : reqResourcesList) {
|
|
||||||
for (AuthGroup authRuleGroup : authGroups) {
|
for (AuthGroup authRuleGroup : authGroups) {
|
||||||
List<AuthRule> authRules = authRuleGroup.getAuthRules();
|
List<AuthRule> authRules = authRuleGroup.getAuthRules();
|
||||||
List<String> allAuthItems = new ArrayList<>();
|
for (AuthRule authRule : authRules) {
|
||||||
authRules.forEach(authRule -> allAuthItems.addAll(authRule.resourceNames()));
|
for (String resBizName : authRule.resourceNames()) {
|
||||||
|
resource.getAuthResList().add(new AuthRes(modelId, resBizName));
|
||||||
if (allAuthItems.contains(reqRes.getName())) {
|
|
||||||
rg.getGroup().add(reqRes);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!CollectionUtils.isEmpty(rg.getGroup())) {
|
|
||||||
resource.getResources().add(rg);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Set<Map.Entry<Long, List<AuthGroup>>> entries = authGroupsByModelId.entrySet();
|
||||||
if (!CollectionUtils.isEmpty(req.getModelIds())) {
|
for (Map.Entry<Long, List<AuthGroup>> entry : entries) {
|
||||||
List<AuthGroup> authGroups = Lists.newArrayList();
|
List<AuthGroup> authGroups = entry.getValue();
|
||||||
for (Long modelId : authGroupsByModelId.keySet()) {
|
for (AuthGroup authGroup : authGroups) {
|
||||||
authGroups.addAll(authGroupsByModelId.getOrDefault(modelId, Lists.newArrayList()));
|
|
||||||
}
|
|
||||||
if (!CollectionUtils.isEmpty(authGroups)) {
|
|
||||||
for (AuthGroup group : authGroups) {
|
|
||||||
if (group.getDimensionFilters() != null
|
|
||||||
&& group.getDimensionFilters().stream().anyMatch(expr ->
|
|
||||||
!StringUtils.isEmpty(expr))) {
|
|
||||||
DimensionFilter df = new DimensionFilter();
|
DimensionFilter df = new DimensionFilter();
|
||||||
df.setDescription(group.getDimensionFilterDescription());
|
df.setDescription(authGroup.getDimensionFilterDescription());
|
||||||
df.setExpressions(group.getDimensionFilters());
|
df.setExpressions(authGroup.getDimensionFilters());
|
||||||
resource.getFilters().add(df);
|
resource.getFilters().add(df);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
return resource;
|
return resource;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
|
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
|
||||||
List<AuthGroup> groups = load().stream()
|
List<AuthGroup> groups = load().stream()
|
||||||
.filter(group -> {
|
.filter(group -> {
|
||||||
if (CollectionUtils.isEmpty(modelIds) || !modelIds.contains(group.getModelId())) {
|
if (!modelIds.contains(group.getModelId())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers())
|
||||||
.contains(userName)) {
|
&& group.getAuthorizedUsers().contains(userName)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
for (String departmentId : departmentIds) {
|
for (String departmentId : departmentIds) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -4,8 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -23,6 +24,7 @@ public class Agent extends RecordInfo {
|
|||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private Integer enableSearch;
|
private Integer enableSearch;
|
||||||
|
private Integer enableMemoryReview;
|
||||||
private String name;
|
private String name;
|
||||||
private String description;
|
private String description;
|
||||||
|
|
||||||
@@ -32,7 +34,8 @@ public class Agent extends RecordInfo {
|
|||||||
private Integer status;
|
private Integer status;
|
||||||
private List<String> examples;
|
private List<String> examples;
|
||||||
private String agentConfig;
|
private String agentConfig;
|
||||||
private LLMConfig llmConfig;
|
private ChatModelConfig modelConfig;
|
||||||
|
private PromptConfig promptConfig;
|
||||||
private MultiTurnConfig multiTurnConfig;
|
private MultiTurnConfig multiTurnConfig;
|
||||||
private VisualConfig visualConfig;
|
private VisualConfig visualConfig;
|
||||||
|
|
||||||
@@ -58,6 +61,10 @@ public class Agent extends RecordInfo {
|
|||||||
return enableSearch != null && enableSearch == 1;
|
return enableSearch != null && enableSearch == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean enableMemoryReview() {
|
||||||
|
return enableMemoryReview != null && enableMemoryReview == 1;
|
||||||
|
}
|
||||||
|
|
||||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||||
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.executor;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|
||||||
|
|
||||||
public interface ChatExecutor {
|
|
||||||
|
|
||||||
QueryResult execute(ChatExecuteContext chatExecuteContext);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.executor;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
|
|
||||||
|
public interface ChatQueryExecutor {
|
||||||
|
|
||||||
|
QueryResult execute(ExecuteContext executeContext);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,29 +1,30 @@
|
|||||||
package com.tencent.supersonic.chat.server.executor;
|
package com.tencent.supersonic.chat.server.executor;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||||
|
|
||||||
public class PlainTextExecutor implements ChatExecutor {
|
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
private static final String INSTRUCTION = ""
|
private static final String INSTRUCTION = ""
|
||||||
+ "#Role: You are a nice person to talk to.\n"
|
+ "#Role: You are a nice person to talk to.\n"
|
||||||
@@ -34,34 +35,34 @@ public class PlainTextExecutor implements ChatExecutor {
|
|||||||
+ "#Your response: ";
|
+ "#Your response: ";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
|
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
|
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
|
||||||
chatExecuteContext.getQueryText());
|
executeContext.getQueryText());
|
||||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
|
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
QueryResult result = new QueryResult();
|
QueryResult result = new QueryResult();
|
||||||
result.setQueryState(QueryState.SUCCESS);
|
result.setQueryState(QueryState.SUCCESS);
|
||||||
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
result.setQueryMode(executeContext.getParseInfo().getQueryMode());
|
||||||
result.setTextResult(response.content().text());
|
result.setTextResult(response.content().text());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
|
private String getHistoryInputs(ExecuteContext executeContext) {
|
||||||
StringBuilder historyInput = new StringBuilder();
|
StringBuilder historyInput = new StringBuilder();
|
||||||
|
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||||
|
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
||||||
@@ -70,8 +71,8 @@ public class PlainTextExecutor implements ChatExecutor {
|
|||||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||||
|
|
||||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||||
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
|
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
||||||
parseResps.stream().forEach(p -> {
|
queryResps.stream().forEach(p -> {
|
||||||
historyInput.append(p.getQueryText());
|
historyInput.append(p.getQueryText());
|
||||||
historyInput.append(";");
|
historyInput.append(";");
|
||||||
});
|
});
|
||||||
@@ -80,12 +81,15 @@ public class PlainTextExecutor implements ChatExecutor {
|
|||||||
return historyInput.toString();
|
return historyInput.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
.stream()
|
||||||
|
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||||
|
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||||
Math.min(multiNum, contextualParseInfoList.size()));
|
Math.min(multiNum, contextualParseInfoList.size()));
|
||||||
Collections.reverse(contextualList);
|
Collections.reverse(contextualList);
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,15 @@ package com.tencent.supersonic.chat.server.executor;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
|
|
||||||
public class PluginExecutor implements ChatExecutor {
|
public class PluginExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||||
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,28 +2,34 @@ package com.tencent.supersonic.chat.server.executor;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||||
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import java.util.Date;
|
|
||||||
|
|
||||||
public class SqlExecutor implements ChatExecutor {
|
import java.util.Date;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class SqlExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
QueryResult queryResult = doExecute(executeContext);
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
|
||||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
|
||||||
if (queryResult != null) {
|
if (queryResult != null) {
|
||||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||||
queryResult.getQueryResults());
|
queryResult.getQueryResults());
|
||||||
@@ -31,14 +37,20 @@ public class SqlExecutor implements ChatExecutor {
|
|||||||
|
|
||||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||||
|
Text2SQLExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||||
|
executeContext.getParseInfo().getProperties()
|
||||||
|
.get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class);
|
||||||
|
|
||||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||||
memoryService.createMemory(ChatMemoryDO.builder()
|
memoryService.createMemory(ChatMemoryDO.builder()
|
||||||
.agentId(chatExecuteContext.getAgentId())
|
.agentId(executeContext.getAgent().getId())
|
||||||
.status(MemoryStatus.PENDING)
|
.status(MemoryStatus.PENDING)
|
||||||
.question(chatExecuteContext.getQueryText())
|
.question(exemplar.getQuestion())
|
||||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
|
.sideInfo(exemplar.getSideInfo())
|
||||||
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
.dbSchema(exemplar.getDbSchema())
|
||||||
.createdBy(chatExecuteContext.getUser().getName())
|
.s2sql(exemplar.getSql())
|
||||||
|
.createdBy(executeContext.getUser().getName())
|
||||||
|
.updatedBy(executeContext.getUser().getName())
|
||||||
.createdAt(new Date())
|
.createdAt(new Date())
|
||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
@@ -47,48 +59,43 @@ public class SqlExecutor implements ChatExecutor {
|
|||||||
return queryResult;
|
return queryResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
private ExecuteQueryReq buildExecuteReq(ChatExecuteContext chatExecuteContext) {
|
@SneakyThrows
|
||||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
private QueryResult doExecute(ExecuteContext executeContext) {
|
||||||
return ExecuteQueryReq.builder()
|
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
||||||
.queryId(chatExecuteContext.getQueryId())
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
.chatId(chatExecuteContext.getChatId())
|
|
||||||
.queryText(chatExecuteContext.getQueryText())
|
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
|
||||||
.parseInfo(parseInfo)
|
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||||
.saveAnswer(chatExecuteContext.isSaveAnswer())
|
if (Objects.isNull(parseInfo.getSqlInfo())
|
||||||
.user(chatExecuteContext.getUser())
|
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
QuerySqlReq sqlReq = QuerySqlReq.builder()
|
||||||
|
.sql(parseInfo.getSqlInfo().getCorrectedS2SQL())
|
||||||
.build();
|
.build();
|
||||||
}
|
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||||
|
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||||
|
|
||||||
public String buildSchemaStr(SemanticParseInfo parseInfo) {
|
long startTime = System.currentTimeMillis();
|
||||||
String tableStr = parseInfo.getDataSet().getName();
|
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
|
||||||
StringBuilder metricStr = new StringBuilder();
|
QueryResult queryResult = new QueryResult();
|
||||||
StringBuilder dimensionStr = new StringBuilder();
|
queryResult.setChatContext(parseInfo);
|
||||||
|
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||||
|
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||||
|
if (queryResp != null) {
|
||||||
|
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||||
|
queryResult.setQuerySql(queryResp.getSql());
|
||||||
|
queryResult.setQueryResults(queryResp.getResultList());
|
||||||
|
queryResult.setQueryColumns(queryResp.getColumns());
|
||||||
|
queryResult.setQueryState(QueryState.SUCCESS);
|
||||||
|
|
||||||
parseInfo.getMetrics().stream().forEach(
|
chatCtx.setParseInfo(parseInfo);
|
||||||
metric -> {
|
chatContextService.updateContext(chatCtx);
|
||||||
metricStr.append(metric.getName());
|
} else {
|
||||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
queryResult.setQueryState(QueryState.INVALID);
|
||||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
|
||||||
}
|
}
|
||||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
return queryResult;
|
||||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
|
||||||
}
|
|
||||||
metricStr.append(",");
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
parseInfo.getDimensions().stream().forEach(
|
|
||||||
dimension -> {
|
|
||||||
dimensionStr.append(dimension.getName());
|
|
||||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
|
||||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
|
||||||
}
|
|
||||||
dimensionStr.append(",");
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
|
||||||
return String.format(template, tableStr, metricStr, dimensionStr);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
@@ -32,9 +32,11 @@ public class MemoryReviewTask {
|
|||||||
+ "please take a review and give your opinion.\n"
|
+ "please take a review and give your opinion.\n"
|
||||||
+ "#Rules: "
|
+ "#Rules: "
|
||||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||||
+ "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n"
|
+ "2.ALWAYS recognize `数据日期` as the date field."
|
||||||
|
+ "3.IGNORE `数据日期` if not expressed in the `Question`."
|
||||||
+ "#Question: %s\n"
|
+ "#Question: %s\n"
|
||||||
+ "#Schema: %s\n"
|
+ "#Schema: %s\n"
|
||||||
|
+ "#SideInfo: %s\n"
|
||||||
+ "#SQL: %s\n"
|
+ "#SQL: %s\n"
|
||||||
+ "#Response: ";
|
+ "#Response: ";
|
||||||
|
|
||||||
@@ -51,28 +53,33 @@ public class MemoryReviewTask {
|
|||||||
memoryService.getMemoriesForLlmReview().stream()
|
memoryService.getMemoriesForLlmReview().stream()
|
||||||
.forEach(m -> {
|
.forEach(m -> {
|
||||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||||
if (Objects.nonNull(chatAgent)) {
|
if (Objects.nonNull(chatAgent) && chatAgent.enableMemoryReview()) {
|
||||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
|
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(),
|
||||||
|
m.getSideInfo(), m.getS2sql());
|
||||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
|
|
||||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||||
chatAgent.getLlmConfig());
|
chatAgent.getModelConfig());
|
||||||
if (Objects.nonNull(chatLanguageModel)) {
|
if (Objects.nonNull(chatLanguageModel)) {
|
||||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||||
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||||
|
|
||||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||||
if (matcher.find()) {
|
if (matcher.find()) {
|
||||||
m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1)));
|
m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1)));
|
||||||
m.setLlmReviewCmt(matcher.group(2));
|
m.setLlmReviewCmt(matcher.group(2));
|
||||||
|
// directly enable memory if the LLM determines it positive
|
||||||
|
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
|
||||||
|
memoryService.enableMemory(m);
|
||||||
|
}
|
||||||
memoryService.updateMemory(m);
|
memoryService.updateMemory(m);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.debug("ChatLanguageModel not found for agent:{}", chatAgent.getId());
|
log.debug("ChatLanguageModel not found for agent:{}", chatAgent.getId());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.debug("Agent not found for memory:{}", m.getAgentId());
|
log.debug("Agent id {} not found or memory review disabled", m.getAgentId());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
|
|
||||||
public interface ChatParser {
|
|
||||||
|
|
||||||
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
|
||||||
|
public interface ChatQueryParser {
|
||||||
|
|
||||||
|
void parse(ParseContext parseContext, ParseResp parseResp);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
@@ -9,18 +9,18 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class NL2PluginParser implements ChatParser {
|
public class NL2PluginParser implements ChatQueryParser {
|
||||||
|
|
||||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||||
if (!chatParseContext.getAgent().containsPluginTool()) {
|
if (!parseContext.getAgent().containsPluginTool()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
pluginRecognizer.recognize(parseContext, parseResp);
|
||||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||||
JsonUtil.toString(parseResp));
|
JsonUtil.toString(parseResp));
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
|
||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -17,65 +16,93 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
|
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class NL2SQLParser implements ChatParser {
|
public class NL2SQLParser implements ChatQueryParser {
|
||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
private static final String REWRITE_INSTRUCTION = ""
|
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
|
||||||
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
+ "#Role: You are a data product manager experienced in data requirements."
|
||||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||||
+ "please try understanding the semantics and rewrite a question.\n"
|
+ "please try understanding the semantics and rewrite a question."
|
||||||
+ "#Rules: "
|
+ "#Rules: "
|
||||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
|
||||||
+ "2.ONLY respond with the rewritten question.\n"
|
+ "2.ONLY respond with the rewritten question."
|
||||||
+ "#Current Question: %s\n"
|
+ "#Current Question: {{current_question}}"
|
||||||
+ "#Current Mapped Schema: %s\n"
|
+ "#Current Mapped Schema: {{current_schema}}"
|
||||||
+ "#History Question: %s\n"
|
+ "#History Question: {{history_question}}"
|
||||||
+ "#History Mapped Schema: %s\n"
|
+ "#History Mapped Schema: {{history_schema}}"
|
||||||
+ "#History SQL: %s\n"
|
+ "#History SQL: {{history_sql}}"
|
||||||
+ "#Rewritten Question: ";
|
+ "#Rewritten Question: ";
|
||||||
|
|
||||||
|
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
||||||
|
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||||
|
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||||
|
+ "please respond shortly to teach user how to ask the right question, "
|
||||||
|
+ "by using `Examples` as references."
|
||||||
|
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
|
||||||
|
+ "#Input: {{user_question}}\n"
|
||||||
|
+ "#Output: {{system_message}}\n"
|
||||||
|
+ "#Examples: {{examples}}\n"
|
||||||
|
+ "#Response: ";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
processMultiTurn(chatParseContext);
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
|
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
||||||
|
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||||
addExemplars(chatParseContext.getAgent().getId(), queryReq);
|
parseContext.getAgent().getModelConfig());
|
||||||
|
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
processMultiTurn(chatLanguageModel, parseContext);
|
||||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||||
|
|
||||||
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
|
ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
|
||||||
|
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||||
|
} else {
|
||||||
|
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
|
||||||
|
parseContext.getQueryText(),
|
||||||
|
text2SqlParseResp.getErrorMsg(),
|
||||||
|
queryNLReq.getDynamicExemplars(),
|
||||||
|
parseContext.getAgent().getExamples()));
|
||||||
}
|
}
|
||||||
|
parseResp.setState(text2SqlParseResp.getState());
|
||||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||||
formatParseResult(parseResp);
|
formatParseResult(parseResp);
|
||||||
}
|
}
|
||||||
@@ -135,9 +162,9 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
parseInfo.setTextInfo(textBuilder.toString());
|
parseInfo.setTextInfo(textBuilder.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void processMultiTurn(ChatParseContext chatParseContext) {
|
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
||||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||||
|
|
||||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||||
@@ -147,45 +174,63 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// derive mapping result of current question and parsing result of last question.
|
// derive mapping result of current question and parsing result of last question.
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||||
MapResp currentMapResult = chatQueryService.performMapping(queryReq);
|
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
|
||||||
|
|
||||||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
|
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
||||||
if (historyParseResults.size() == 0) {
|
if (historyQueries.size() == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ParseResp lastParseResult = historyParseResults.get(0);
|
QueryResp lastQuery = historyQueries.get(0);
|
||||||
Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId();
|
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||||
|
Long dataId = lastParseInfo.getDataSetId();
|
||||||
|
|
||||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
|
||||||
.curtQuestion(currentMapResult.getQueryText())
|
Map<String, Object> variables = new HashMap<>();
|
||||||
.histQuestion(lastParseResult.getQueryText())
|
variables.put("current_question", currentMapResult.getQueryText());
|
||||||
.curtSchema(curtMapStr)
|
variables.put("current_schema", curtMapStr);
|
||||||
.histSchema(histMapStr)
|
variables.put("history_question", lastQuery.getQueryText());
|
||||||
.histSQL(histSQL)
|
variables.put("history_schema", histMapStr);
|
||||||
.llmConfig(queryReq.getLlmConfig())
|
variables.put("history_sql", histSQL);
|
||||||
.build());
|
|
||||||
chatParseContext.setQueryText(rewrittenQuery);
|
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
|
||||||
|
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||||
|
|
||||||
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
String rewrittenQuery = response.content().text();
|
||||||
|
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
|
||||||
|
|
||||||
|
parseContext.setQueryText(rewrittenQuery);
|
||||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String rewriteQuery(RewriteContext context) {
|
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||||
String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(), context.getCurtSchema(),
|
String errMsg, List<Text2SQLExemplar> similarExemplars,
|
||||||
context.getHistQuestion(), context.getHistSchema(), context.getHistSQL());
|
List<String> agentExamples) {
|
||||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
Map<String, Object> variables = new HashMap<>();
|
||||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
|
variables.put("user_question", userQuestion);
|
||||||
|
variables.put("system_message", errMsg);
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
|
StringBuilder exampleStr = new StringBuilder();
|
||||||
|
similarExemplars.forEach(e ->
|
||||||
|
exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
|
||||||
|
agentExamples.forEach(e ->
|
||||||
|
exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||||
|
variables.put("examples", exampleStr);
|
||||||
|
|
||||||
|
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||||
|
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
String result = response.content().text();
|
String rewrittenMsg = response.content().text();
|
||||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", result);
|
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenMsg);
|
||||||
return response.content().text();
|
|
||||||
|
return rewrittenMsg;
|
||||||
}
|
}
|
||||||
|
|
||||||
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
||||||
@@ -213,36 +258,27 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
return prompt.toString();
|
return prompt.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
.stream()
|
||||||
|
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||||
|
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||||
Math.min(multiNum, contextualParseInfoList.size()));
|
Math.min(multiNum, contextualParseInfoList.size()));
|
||||||
Collections.reverse(contextualList);
|
Collections.reverse(contextualList);
|
||||||
return contextualList;
|
return contextualList;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) {
|
||||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||||
queryReq.getQueryText(), 5);
|
queryNLReq.getQueryText(), 5);
|
||||||
queryReq.getExemplars().addAll(exemplars);
|
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||||
}
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
@Data
|
|
||||||
public static class RewriteContext {
|
|
||||||
|
|
||||||
private String curtQuestion;
|
|
||||||
private String histQuestion;
|
|
||||||
private String curtSchema;
|
|
||||||
private String histSchema;
|
|
||||||
private String histSQL;
|
|
||||||
private LLMConfig llmConfig;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,11 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
|
||||||
import com.tencent.supersonic.common.config.ParameterConfig;
|
import com.tencent.supersonic.common.config.ParameterConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.List;
|
@Service("ChatQueryParserConfig")
|
||||||
|
|
||||||
@Service("ChatParserConfig")
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ParserConfig extends ParameterConfig {
|
public class ParserConfig extends ParameterConfig {
|
||||||
|
|
||||||
@@ -17,11 +14,4 @@ public class ParserConfig extends ParameterConfig {
|
|||||||
"是否开启多轮对话", "开启多轮对话将消耗更多token",
|
"是否开启多轮对话", "开启多轮对话将消耗更多token",
|
||||||
"bool", "Parser相关配置");
|
"bool", "Parser相关配置");
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<Parameter> getSysParameters() {
|
|
||||||
return Lists.newArrayList(
|
|
||||||
PARSER_MULTI_TURN_ENABLE
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,22 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
|
||||||
|
|
||||||
public class PlainTextParser implements ChatParser {
|
public class PlainTextParser implements ChatQueryParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||||
if (chatParseContext.getAgent().containsAnyTool()) {
|
if (parseContext.getAgent().containsAnyTool()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||||
parseResp.getSelectedParses().add(parseInfo);
|
parseResp.getSelectedParses().add(parseInfo);
|
||||||
|
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.persistence.dataobject;
|
|||||||
import com.baomidou.mybatisplus.annotation.IdType;
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
import com.baomidou.mybatisplus.annotation.TableId;
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
import com.baomidou.mybatisplus.annotation.TableName;
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
@@ -12,15 +11,18 @@ import java.util.Date;
|
|||||||
@TableName("s2_agent")
|
@TableName("s2_agent")
|
||||||
public class AgentDO {
|
public class AgentDO {
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
@TableId(type = IdType.AUTO)
|
@TableId(type = IdType.AUTO)
|
||||||
private Integer id;
|
private Integer id;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private String description;
|
private String description;
|
||||||
|
|
||||||
@@ -30,37 +32,45 @@ public class AgentDO {
|
|||||||
private Integer status;
|
private Integer status;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private String examples;
|
private String examples;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private String config;
|
private String config;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private String createdBy;
|
private String createdBy;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private Date createdAt;
|
private Date createdAt;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private String updatedBy;
|
private String updatedBy;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private Date updatedAt;
|
private Date updatedAt;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
*/
|
*/
|
||||||
private Integer enableSearch;
|
private Integer enableSearch;
|
||||||
|
private Integer enableMemoryReview;
|
||||||
private String llmConfig;
|
private String modelConfig;
|
||||||
|
|
||||||
private String multiTurnConfig;
|
private String multiTurnConfig;
|
||||||
|
|
||||||
private String visualConfig;
|
private String visualConfig;
|
||||||
|
|
||||||
|
private String promptConfig;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.server.persistence.dataobject;
|
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -20,11 +20,14 @@ public class ChatMemoryDO {
|
|||||||
@TableId(type = IdType.AUTO)
|
@TableId(type = IdType.AUTO)
|
||||||
private Long id;
|
private Long id;
|
||||||
|
|
||||||
|
@TableField("agent_id")
|
||||||
|
private Integer agentId;
|
||||||
|
|
||||||
@TableField("question")
|
@TableField("question")
|
||||||
private String question;
|
private String question;
|
||||||
|
|
||||||
@TableField("agent_id")
|
@TableField("side_info")
|
||||||
private Integer agentId;
|
private String sideInfo;
|
||||||
|
|
||||||
@TableField("db_schema")
|
@TableField("db_schema")
|
||||||
private String dbSchema;
|
private String dbSchema;
|
||||||
|
|||||||
@@ -1,142 +1,25 @@
|
|||||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
|
||||||
|
@Data
|
||||||
public class ChatParseDO {
|
public class ChatParseDO {
|
||||||
|
|
||||||
/**
|
|
||||||
* questionId
|
|
||||||
*/
|
|
||||||
private Long questionId;
|
private Long questionId;
|
||||||
|
|
||||||
/**
|
private Integer chatId;
|
||||||
* chatId
|
|
||||||
*/
|
|
||||||
private Long chatId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* parseId
|
|
||||||
*/
|
|
||||||
private Integer parseId;
|
private Integer parseId;
|
||||||
|
|
||||||
/**
|
|
||||||
* createTime
|
|
||||||
*/
|
|
||||||
private Date createTime;
|
private Date createTime;
|
||||||
|
|
||||||
/**
|
|
||||||
* queryText
|
|
||||||
*/
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
|
||||||
/**
|
|
||||||
* userName
|
|
||||||
*/
|
|
||||||
private String userName;
|
private String userName;
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* parseInfo
|
|
||||||
*/
|
|
||||||
private String parseInfo;
|
private String parseInfo;
|
||||||
|
|
||||||
/**
|
|
||||||
* isCandidate
|
|
||||||
*/
|
|
||||||
private Integer isCandidate;
|
private Integer isCandidate;
|
||||||
|
|
||||||
/**
|
|
||||||
* return question_id
|
|
||||||
*/
|
|
||||||
public Long getQuestionId() {
|
|
||||||
return questionId;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* questionId
|
|
||||||
*/
|
|
||||||
public void setQuestionId(Long questionId) {
|
|
||||||
this.questionId = questionId;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* return create_time
|
|
||||||
*/
|
|
||||||
public Date getCreateTime() {
|
|
||||||
return createTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* createTime
|
|
||||||
*/
|
|
||||||
public void setCreateTime(Date createTime) {
|
|
||||||
this.createTime = createTime;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* return user_name
|
|
||||||
*/
|
|
||||||
public String getUserName() {
|
|
||||||
return userName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* userName
|
|
||||||
*/
|
|
||||||
public void setUserName(String userName) {
|
|
||||||
this.userName = userName == null ? null : userName.trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* return chat_id
|
|
||||||
*/
|
|
||||||
public Long getChatId() {
|
|
||||||
return chatId;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* chatId
|
|
||||||
*/
|
|
||||||
public void setChatId(Long chatId) {
|
|
||||||
this.chatId = chatId;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* return query_text
|
|
||||||
*/
|
|
||||||
public String getQueryText() {
|
|
||||||
return queryText;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* queryText
|
|
||||||
*/
|
|
||||||
public void setQueryText(String queryText) {
|
|
||||||
this.queryText = queryText == null ? null : queryText.trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
public Integer getIsCandidate() {
|
|
||||||
return isCandidate;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Integer getParseId() {
|
|
||||||
return parseId;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getParseInfo() {
|
|
||||||
return parseInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setParseId(Integer parseId) {
|
|
||||||
this.parseId = parseId;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setIsCandidate(Integer isCandidate) {
|
|
||||||
this.isCandidate = isCandidate;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setParseInfo(String parseInfo) {
|
|
||||||
this.parseInfo = parseInfo;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.server.persistence.dataobject;
|
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.server.persistence.mapper;
|
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||||
import org.apache.ibatis.annotations.Mapper;
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
|
||||||
@Mapper
|
@Mapper
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.server.persistence.mapper;
|
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||||
import org.apache.ibatis.annotations.Mapper;
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
import org.apache.ibatis.annotations.Param;
|
import org.apache.ibatis.annotations.Param;
|
||||||
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.server.persistence.repository;
|
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
|
|
||||||
public interface ChatContextRepository {
|
public interface ChatContextRepository {
|
||||||
|
|
||||||
@@ -18,6 +18,8 @@ public interface ChatQueryRepository {
|
|||||||
|
|
||||||
QueryResp getChatQuery(Long queryId);
|
QueryResp getChatQuery(Long queryId);
|
||||||
|
|
||||||
|
List<QueryResp> getChatQueries(Integer chatId);
|
||||||
|
|
||||||
ChatQueryDO getChatQueryDO(Long queryId);
|
ChatQueryDO getChatQueryDO(Long queryId);
|
||||||
|
|
||||||
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
||||||
@@ -35,6 +37,4 @@ public interface ChatQueryRepository {
|
|||||||
|
|
||||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||||
|
|
||||||
List<ParseResp> getContextualParseInfo(Integer chatId);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.server.persistence.repository.impl;
|
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||||
|
|
||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
|
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
|
||||||
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
|
|
||||||
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
@@ -20,7 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
@@ -61,7 +61,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
if (!CollectionUtils.isEmpty(pageQueryInfoReq.getIds())) {
|
if (!CollectionUtils.isEmpty(pageQueryInfoReq.getIds())) {
|
||||||
queryWrapper.lambda().in(ChatQueryDO::getQuestionId, pageQueryInfoReq.getIds());
|
queryWrapper.lambda().in(ChatQueryDO::getQuestionId, pageQueryInfoReq.getIds());
|
||||||
}
|
}
|
||||||
|
queryWrapper.lambda().isNotNull(ChatQueryDO::getQueryResult);
|
||||||
|
queryWrapper.lambda().ne(ChatQueryDO::getQueryResult, "");
|
||||||
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
||||||
|
|
||||||
PageInfo<ChatQueryDO> pageInfo = PageHelper.startPage(pageQueryInfoReq.getCurrent(),
|
PageInfo<ChatQueryDO> pageInfo = PageHelper.startPage(pageQueryInfoReq.getCurrent(),
|
||||||
@@ -70,8 +71,9 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
|
|
||||||
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
|
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
|
||||||
chatQueryVOPageInfo.setList(
|
chatQueryVOPageInfo.setList(
|
||||||
pageInfo.getList().stream().filter(o -> !StringUtils.isEmpty(o.getQueryResult())).map(this::convertTo)
|
pageInfo.getList().stream()
|
||||||
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
|
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
|
||||||
|
.map(this::convertTo)
|
||||||
.collect(Collectors.toList()));
|
.collect(Collectors.toList()));
|
||||||
return chatQueryVOPageInfo;
|
return chatQueryVOPageInfo;
|
||||||
}
|
}
|
||||||
@@ -90,6 +92,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
return chatQueryDOMapper.selectById(queryId);
|
return chatQueryDOMapper.selectById(queryId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||||
|
QueryWrapper<ChatQueryDO> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId);
|
||||||
|
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
||||||
|
return chatQueryDOMapper.selectList(queryWrapper).stream()
|
||||||
|
.map(q -> convertTo(q))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
||||||
@@ -145,7 +157,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||||
for (int i = 0; i < parses.size(); i++) {
|
for (int i = 0; i < parses.size(); i++) {
|
||||||
ChatParseDO chatParseDO = new ChatParseDO();
|
ChatParseDO chatParseDO = new ChatParseDO();
|
||||||
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
|
chatParseDO.setChatId(chatParseReq.getChatId());
|
||||||
chatParseDO.setQuestionId(queryId);
|
chatParseDO.setQuestionId(queryId);
|
||||||
chatParseDO.setQueryText(chatParseReq.getQueryText());
|
chatParseDO.setQueryText(chatParseReq.getQueryText());
|
||||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||||
@@ -179,17 +191,4 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
return chatParseMapper.getParseInfoList(questionIds);
|
return chatParseMapper.getParseInfoList(questionIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ParseResp> getContextualParseInfo(Integer chatId) {
|
|
||||||
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
|
|
||||||
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
|
|
||||||
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
|
|
||||||
List<SemanticParseInfo> selectedParses = new ArrayList<>();
|
|
||||||
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
|
|
||||||
parseResp.setSelectedParses(selectedParses);
|
|
||||||
return parseResp;
|
|
||||||
}).collect(Collectors.toList());
|
|
||||||
return semanticParseInfoList;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
|||||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
@@ -52,9 +52,9 @@ public class PluginManager {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingService embeddingService;
|
private EmbeddingService embeddingService;
|
||||||
|
|
||||||
public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
|
||||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||||
Agent agent = chatParseContext.getAgent();
|
Agent agent = parseContext.getAgent();
|
||||||
List<ChatPlugin> plugins = pluginService.getPluginList();
|
List<ChatPlugin> plugins = pluginService.getPluginList();
|
||||||
if (Objects.isNull(agent)) {
|
if (Objects.isNull(agent)) {
|
||||||
return plugins;
|
return plugins;
|
||||||
@@ -191,9 +191,9 @@ public class PluginManager {
|
|||||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
||||||
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
|
||||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
|
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
|
||||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||||
return Pair.of(false, Sets.newHashSet());
|
return Pair.of(false, Sets.newHashSet());
|
||||||
}
|
}
|
||||||
@@ -259,8 +259,8 @@ public class PluginManager {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
|
||||||
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
|
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
|
||||||
if (plugin.isContainsAllDataSet()) {
|
if (plugin.isContainsAllDataSet()) {
|
||||||
return Sets.newHashSet(plugin.getDefaultMode());
|
return Sets.newHashSet(plugin.getDefaultMode());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
|||||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.common.pojo.Constants;
|
|||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
|||||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -28,22 +28,22 @@ import java.util.Set;
|
|||||||
*/
|
*/
|
||||||
public abstract class PluginRecognizer {
|
public abstract class PluginRecognizer {
|
||||||
|
|
||||||
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
||||||
if (!checkPreCondition(chatParseContext)) {
|
if (!checkPreCondition(parseContext)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
|
PluginRecallResult pluginRecallResult = recallPlugin(parseContext);
|
||||||
if (pluginRecallResult == null) {
|
if (pluginRecallResult == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
buildQuery(chatParseContext, parseResp, pluginRecallResult);
|
buildQuery(parseContext, parseResp, pluginRecallResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
|
public abstract boolean checkPreCondition(ParseContext parseContext);
|
||||||
|
|
||||||
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
|
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
|
||||||
|
|
||||||
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
|
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
|
||||||
PluginRecallResult pluginRecallResult) {
|
PluginRecallResult pluginRecallResult) {
|
||||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||||
@@ -52,35 +52,35 @@ public abstract class PluginRecognizer {
|
|||||||
}
|
}
|
||||||
for (Long dataSetId : dataSetIds) {
|
for (Long dataSetId : dataSetIds) {
|
||||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||||
chatParseContext, pluginRecallResult.getDistance());
|
parseContext, pluginRecallResult.getDistance());
|
||||||
semanticParseInfo.setQueryMode(plugin.getType());
|
semanticParseInfo.setQueryMode(plugin.getType());
|
||||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<ChatPlugin> getPluginList(ChatParseContext chatParseContext) {
|
protected List<ChatPlugin> getPluginList(ParseContext parseContext) {
|
||||||
return PluginManager.getPluginAgentCanSupport(chatParseContext);
|
return PluginManager.getPluginAgentCanSupport(parseContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||||
ChatParseContext chatParseContext, double distance) {
|
ParseContext parseContext, double distance) {
|
||||||
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||||
QueryFilters queryFilters = chatParseContext.getQueryFilters();
|
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
SchemaElement schemaElement = new SchemaElement();
|
SchemaElement schemaElement = new SchemaElement();
|
||||||
schemaElement.setDataSet(dataSetId);
|
schemaElement.setDataSetId(dataSetId);
|
||||||
semanticParseInfo.setDataSet(schemaElement);
|
semanticParseInfo.setDataSet(schemaElement);
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||||
pluginParseResult.setPlugin(plugin);
|
pluginParseResult.setPlugin(plugin);
|
||||||
pluginParseResult.setQueryFilters(queryFilters);
|
pluginParseResult.setQueryFilters(queryFilters);
|
||||||
pluginParseResult.setDistance(distance);
|
pluginParseResult.setDistance(distance);
|
||||||
pluginParseResult.setQueryText(chatParseContext.getQueryText());
|
pluginParseResult.setQueryText(parseContext.getQueryText());
|
||||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||||
properties.put("type", "plugin");
|
properties.put("type", "plugin");
|
||||||
properties.put("name", plugin.getName());
|
properties.put("name", plugin.getName());
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
|||||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
@@ -26,25 +26,25 @@ import java.util.stream.Collectors;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||||
|
|
||||||
public boolean checkPreCondition(ChatParseContext chatParseContext) {
|
public boolean checkPreCondition(ParseContext parseContext) {
|
||||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||||
return !CollectionUtils.isEmpty(plugins);
|
return !CollectionUtils.isEmpty(plugins);
|
||||||
}
|
}
|
||||||
|
|
||||||
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
|
public PluginRecallResult recallPlugin(ParseContext parseContext) {
|
||||||
String text = chatParseContext.getQueryText();
|
String text = parseContext.getQueryText();
|
||||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||||
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
||||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||||
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||||
if (plugin == null) {
|
if (plugin == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
|
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, parseContext);
|
||||||
log.info("embedding plugin resolve: {}", pair);
|
log.info("embedding plugin resolve: {}", pair);
|
||||||
if (pair.getLeft()) {
|
if (pair.getLeft()) {
|
||||||
Set<Long> dataSetList = pair.getRight();
|
Set<Long> dataSetList = pair.getRight();
|
||||||
@@ -53,7 +53,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
|||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double distance = embeddingRetrieval.getDistance();
|
double distance = embeddingRetrieval.getDistance();
|
||||||
double score = chatParseContext.getQueryText().length() * (1 - distance);
|
double score = parseContext.getQueryText().length() * (1 - distance);
|
||||||
return PluginRecallResult.builder()
|
return PluginRecallResult.builder()
|
||||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.headless.chat;
|
package com.tencent.supersonic.chat.server.pojo;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
@@ -6,7 +6,6 @@ import lombok.Data;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ChatContext {
|
public class ChatContext {
|
||||||
|
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private SemanticParseInfo parseInfo = new SemanticParseInfo();
|
private SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
package com.tencent.supersonic.chat.server.pojo;
|
package com.tencent.supersonic.chat.server.pojo;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ChatExecuteContext {
|
public class ExecuteContext {
|
||||||
private User user;
|
private User user;
|
||||||
private Integer agentId;
|
|
||||||
private Long queryId;
|
|
||||||
private Integer chatId;
|
|
||||||
private int parseId;
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
private Agent agent;
|
||||||
|
private Integer chatId;
|
||||||
|
private Long queryId;
|
||||||
private boolean saveAnswer;
|
private boolean saveAnswer;
|
||||||
private SemanticParseInfo parseInfo;
|
private SemanticParseInfo parseInfo;
|
||||||
}
|
}
|
||||||
@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ChatParseContext {
|
public class ParseContext {
|
||||||
private String queryText;
|
|
||||||
private Integer chatId;
|
|
||||||
private Agent agent;
|
|
||||||
private User user;
|
private User user;
|
||||||
|
private String queryText;
|
||||||
|
private Agent agent;
|
||||||
|
private Integer chatId;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo;
|
||||||
|
|
||||||
public boolean enableNL2SQL() {
|
public boolean enableNL2SQL() {
|
||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
@@ -5,5 +5,4 @@ package com.tencent.supersonic.chat.server.processor;
|
|||||||
*/
|
*/
|
||||||
public interface ResultProcessor {
|
public interface ResultProcessor {
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.execute;
|
package com.tencent.supersonic.chat.server.processor.execute;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -28,14 +28,14 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
|||||||
private static final int recommend_dimension_size = 5;
|
private static final int recommend_dimension_size = 5;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||||
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|
||||||
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
|
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
|
||||||
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
|
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSetId());
|
||||||
queryResult.setRecommendedDimensions(dimensionRecommended);
|
queryResult.setRecommendedDimensions(dimensionRecommended);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.execute;
|
package com.tencent.supersonic.chat.server.processor.execute;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A ExecuteResultProcessor wraps things up before returning results to users in execute stage.
|
* A ExecuteResultProcessor wraps things up before returning
|
||||||
|
* execution results to the users.
|
||||||
*/
|
*/
|
||||||
public interface ExecuteResultProcessor extends ResultProcessor {
|
public interface ExecuteResultProcessor extends ResultProcessor {
|
||||||
|
|
||||||
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
|
void process(ExecuteContext executeContext, QueryResult queryResult);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -11,7 +11,7 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
|||||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
@@ -26,7 +26,7 @@ import com.tencent.supersonic.headless.api.pojo.MetricInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
|
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
|
||||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||||
@@ -60,15 +60,15 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||||
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
|
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
|
||||||
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
||||||
|| !aggregatorConfig.getEnableRatio()
|
|| !aggregatorConfig.getEnableRatio()
|
||||||
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
|
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
|
||||||
semanticParseInfo, queryResult);
|
semanticParseInfo, queryResult);
|
||||||
queryResult.setAggregateInfo(aggregateInfo);
|
queryResult.setAggregateInfo(aggregateInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.execute;
|
package com.tencent.supersonic.chat.server.processor.execute;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||||
import java.util.Objects;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -23,6 +22,7 @@ import java.util.HashMap;
|
|||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -34,8 +34,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
|||||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||||
fillSimilarMetric(chatExecuteContext.getParseInfo());
|
fillSimilarMetric(executeContext.getParseInfo());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||||
@@ -45,8 +45,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||||
Map<String, String> filterCondition = new HashMap<>();
|
Map<String, Object> filterCondition = new HashMap<>();
|
||||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSet().toString());
|
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
|
||||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||||
@@ -78,7 +78,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
|||||||
if (retrieval.getMetadata().containsKey("dataSetId")) {
|
if (retrieval.getMetadata().containsKey("dataSetId")) {
|
||||||
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
|
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
|
||||||
.replace(Constants.UNDERLINE, "");
|
.replace(Constants.UNDERLINE, "");
|
||||||
schemaElement.setDataSet(Long.parseLong(dataSetId));
|
schemaElement.setDataSetId(Long.parseLong(dataSetId));
|
||||||
}
|
}
|
||||||
schemaElement.setOrder(++metricOrder);
|
schemaElement.setOrder(++metricOrder);
|
||||||
parseInfo.getMetrics().add(schemaElement);
|
parseInfo.getMetrics().add(schemaElement);
|
||||||
|
|||||||
@@ -1,43 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.parse;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
|
||||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* EntityInfoProcessor fills core attributes of an entity so that
|
|
||||||
* users get to know which entity is parsed out.
|
|
||||||
*/
|
|
||||||
public class EntityInfoProcessor implements ParseResultProcessor {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
|
||||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
|
||||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
selectedParses.forEach(parseInfo -> {
|
|
||||||
String queryMode = parseInfo.getQueryMode();
|
|
||||||
if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//1. set entity info
|
|
||||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
|
||||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
|
||||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
|
|
||||||
if (QueryManager.isTagQuery(queryMode)
|
|
||||||
|| QueryManager.isMetricQuery(queryMode)) {
|
|
||||||
parseInfo.setEntityInfo(entityInfo);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,10 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.parse;
|
package com.tencent.supersonic.chat.server.processor.parse;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
|
||||||
public interface ParseResultProcessor {
|
/**
|
||||||
|
* A ParseResultProcessor wraps things up before returning
|
||||||
|
* parsing results to the users.
|
||||||
|
*/
|
||||||
|
public interface ParseResultProcessor extends ResultProcessor {
|
||||||
|
|
||||||
void process(ChatParseContext chatParseContext, ParseResp parseResp);
|
void process(ParseContext parseContext, ParseResp parseResp);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
@@ -25,15 +25,15 @@ import java.util.stream.Collectors;
|
|||||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||||
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
|
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
|
||||||
}
|
}
|
||||||
|
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
|
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||||
Long queryId = parseResp.getQueryId();
|
Long queryId = parseResp.getQueryId();
|
||||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
|
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
|
||||||
chatParseContext.getAgent().getId());
|
parseContext.getAgent().getId());
|
||||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||||
updateChatQuery(chatQueryDO);
|
updateChatQuery(chatQueryDO);
|
||||||
@@ -43,7 +43,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
|||||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
List<Text2SQLExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||||
return exemplars.stream().map(sqlExemplar ->
|
return exemplars.stream().map(sqlExemplar ->
|
||||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.parse;
|
package com.tencent.supersonic.chat.server.processor.parse;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
public class TimeCostProcessor implements ParseResultProcessor {
|
public class TimeCostProcessor implements ParseResultProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||||
parseResp.getParseTimeCost().setParseTime(
|
parseResp.getParseTimeCost().setParseTime(
|
||||||
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
|||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||||
import org.springframework.web.bind.annotation.PathVariable;
|
import org.springframework.web.bind.annotation.PathVariable;
|
||||||
@@ -15,6 +15,7 @@ import org.springframework.web.bind.annotation.PutMapping;
|
|||||||
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -50,8 +51,8 @@ public class AgentController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/testLLMConn")
|
@PostMapping("/testLLMConn")
|
||||||
public boolean testLLMConn(@RequestBody LLMConfig llmConfig) {
|
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
|
||||||
return LLMConnHelper.testConnection(llmConfig);
|
return LLMConnHelper.testConnection(modelConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
@RequestMapping("/getAgentList")
|
@RequestMapping("/getAgentList")
|
||||||
|
|||||||
@@ -6,11 +6,10 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
@@ -32,20 +31,20 @@ import javax.validation.Valid;
|
|||||||
public class ChatQueryController {
|
public class ChatQueryController {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ChatService chatService;
|
private ChatQueryService chatQueryService;
|
||||||
|
|
||||||
@PostMapping("search")
|
@PostMapping("search")
|
||||||
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||||
HttpServletResponse response) {
|
HttpServletResponse response) {
|
||||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||||
return chatService.search(chatParseReq);
|
return chatQueryService.search(chatParseReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("parse")
|
@PostMapping("parse")
|
||||||
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
||||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||||
return chatService.performParsing(chatParseReq);
|
return chatQueryService.performParsing(chatParseReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("execute")
|
@PostMapping("execute")
|
||||||
@@ -53,7 +52,7 @@ public class ChatQueryController {
|
|||||||
HttpServletRequest request, HttpServletResponse response)
|
HttpServletRequest request, HttpServletResponse response)
|
||||||
throws Exception {
|
throws Exception {
|
||||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||||
return chatService.performExecution(chatExecuteReq);
|
return chatQueryService.performExecution(chatExecuteReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/")
|
@PostMapping("/")
|
||||||
@@ -62,7 +61,7 @@ public class ChatQueryController {
|
|||||||
throws Exception {
|
throws Exception {
|
||||||
User user = UserHolder.findUser(request, response);
|
User user = UserHolder.findUser(request, response);
|
||||||
chatParseReq.setUser(user);
|
chatParseReq.setUser(user);
|
||||||
ParseResp parseResp = chatService.performParsing(chatParseReq);
|
ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||||
throw new InvalidArgumentException("parser error,no selectedParses");
|
throw new InvalidArgumentException("parser error,no selectedParses");
|
||||||
@@ -72,27 +71,20 @@ public class ChatQueryController {
|
|||||||
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
|
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
|
||||||
chatExecuteReq.setQueryId(parseResp.getQueryId());
|
chatExecuteReq.setQueryId(parseResp.getQueryId());
|
||||||
chatExecuteReq.setParseId(semanticParseInfo.getId());
|
chatExecuteReq.setParseId(semanticParseInfo.getId());
|
||||||
return chatService.performExecution(chatExecuteReq);
|
return chatQueryService.performExecution(chatExecuteReq);
|
||||||
}
|
|
||||||
|
|
||||||
@PostMapping("queryContext")
|
|
||||||
public Object queryContext(@RequestBody QueryReq queryCtx,
|
|
||||||
HttpServletRequest request, HttpServletResponse response) {
|
|
||||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
|
||||||
return chatService.queryContext(queryCtx.getChatId());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("queryData")
|
@PostMapping("queryData")
|
||||||
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
|
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
|
||||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||||
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
|
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
|
||||||
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("queryDimensionValue")
|
@PostMapping("queryDimensionValue")
|
||||||
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||||
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
|
|
||||||
|
public interface ChatContextService {
|
||||||
|
|
||||||
|
ChatContext getOrCreateContext(Integer chatId);
|
||||||
|
|
||||||
|
void updateContext(ChatContext chatCtx);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -35,6 +35,8 @@ public interface ChatManageService {
|
|||||||
|
|
||||||
QueryResp getChatQuery(Long queryId);
|
QueryResp getChatQuery(Long queryId);
|
||||||
|
|
||||||
|
List<QueryResp> getChatQueries(Integer chatId);
|
||||||
|
|
||||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
||||||
|
|
||||||
ChatQueryDO saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
|
ChatQueryDO saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
|
||||||
|
|||||||
@@ -4,15 +4,14 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface ChatService {
|
public interface ChatQueryService {
|
||||||
|
|
||||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||||
|
|
||||||
@@ -24,8 +23,6 @@ public interface ChatService {
|
|||||||
|
|
||||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||||
|
|
||||||
SemanticParseInfo queryContext(Integer chatId);
|
|
||||||
|
|
||||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -16,6 +16,10 @@ public interface MemoryService {
|
|||||||
|
|
||||||
void updateMemory(ChatMemoryDO memory);
|
void updateMemory(ChatMemoryDO memory);
|
||||||
|
|
||||||
|
void enableMemory(ChatMemoryDO memory);
|
||||||
|
|
||||||
|
void disableMemory(ChatMemoryDO memory);
|
||||||
|
|
||||||
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
|
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
|
||||||
|
|
||||||
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
|
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.server.service;
|
package com.tencent.supersonic.chat.server.service;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|||||||
@@ -9,16 +9,18 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
@@ -34,7 +36,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
private MemoryService memoryService;
|
private MemoryService memoryService;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ChatService chatService;
|
private ChatQueryService chatQueryService;
|
||||||
|
|
||||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||||
|
|
||||||
@@ -78,6 +80,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
/**
|
/**
|
||||||
* the example in the agent will be executed by default,
|
* the example in the agent will be executed by default,
|
||||||
* if the result is correct, it will be put into memory as a reference for LLM
|
* if the result is correct, it will be put into memory as a reference for LLM
|
||||||
|
*
|
||||||
* @param agent
|
* @param agent
|
||||||
*/
|
*/
|
||||||
private void executeAgentExamplesAsync(Agent agent) {
|
private void executeAgentExamplesAsync(Agent agent) {
|
||||||
@@ -85,9 +88,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
}
|
}
|
||||||
|
|
||||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||||
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig())) {
|
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|
||||||
|
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<String> examples = agent.getExamples();
|
List<String> examples = agent.getExamples();
|
||||||
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId())
|
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId())
|
||||||
.questions(examples).build();
|
.questions(examples).build();
|
||||||
@@ -98,7 +103,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
chatService.parseAndExecute(-1, agent.getId(), example);
|
chatQueryService.parseAndExecute(-1, agent.getId(), example);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
|
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
|
||||||
}
|
}
|
||||||
@@ -117,7 +122,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
BeanUtils.copyProperties(agentDO, agent);
|
BeanUtils.copyProperties(agentDO, agent);
|
||||||
agent.setAgentConfig(agentDO.getConfig());
|
agent.setAgentConfig(agentDO.getConfig());
|
||||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||||
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
|
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
|
||||||
|
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||||
return agent;
|
return agent;
|
||||||
@@ -128,9 +134,10 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
|||||||
BeanUtils.copyProperties(agent, agentDO);
|
BeanUtils.copyProperties(agent, agentDO);
|
||||||
agentDO.setConfig(agent.getAgentConfig());
|
agentDO.setConfig(agent.getAgentConfig());
|
||||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||||
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
|
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
|
||||||
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
||||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||||
|
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
|
||||||
if (agentDO.getStatus() == null) {
|
if (agentDO.getStatus() == null) {
|
||||||
agentDO.setStatus(1);
|
agentDO.setStatus(1);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service.impl;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class ChatContextServiceImpl implements ChatContextService {
|
||||||
|
|
||||||
|
private ChatContextRepository chatContextRepository;
|
||||||
|
|
||||||
|
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
|
||||||
|
this.chatContextRepository = chatContextRepository;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatContext getOrCreateContext(Integer chatId) {
|
||||||
|
return chatContextRepository.getOrCreateContext(chatId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateContext(ChatContext chatCtx) {
|
||||||
|
log.debug("save ChatContext {}", chatCtx);
|
||||||
|
chatContextRepository.updateContext(chatCtx);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -18,7 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
|||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -107,6 +107,13 @@ public class ChatManageServiceImpl implements ChatManageService {
|
|||||||
return chatQueryRepository.getChatQuery(queryId);
|
return chatQueryRepository.getChatQuery(queryId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||||
|
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
|
||||||
|
fillParseInfo(queries);
|
||||||
|
return queries;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||||
ShowCaseResp showCaseResp = new ShowCaseResp();
|
ShowCaseResp showCaseResp = new ShowCaseResp();
|
||||||
|
|||||||
@@ -0,0 +1,548 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service.impl;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
|
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||||
|
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||||
|
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||||
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||||
|
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||||
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
|
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||||
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import net.sf.jsqlparser.expression.LongValue;
|
||||||
|
import net.sf.jsqlparser.expression.StringValue;
|
||||||
|
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||||
|
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||||
|
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||||
|
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||||
|
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||||
|
import net.sf.jsqlparser.schema.Column;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Iterator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class ChatQueryServiceImpl implements ChatQueryService {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ChatManageService chatManageService;
|
||||||
|
@Autowired
|
||||||
|
private ChatLayerService chatLayerService;
|
||||||
|
@Autowired
|
||||||
|
private SemanticLayerService semanticLayerService;
|
||||||
|
@Autowired
|
||||||
|
private AgentService agentService;
|
||||||
|
|
||||||
|
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||||
|
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||||
|
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||||
|
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||||
|
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||||
|
Agent agent = parseContext.getAgent();
|
||||||
|
if (!agent.enableSearch()) {
|
||||||
|
return Lists.newArrayList();
|
||||||
|
}
|
||||||
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||||
|
return chatLayerService.retrieve(queryNLReq);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||||
|
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||||
|
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||||
|
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||||
|
supplyMapInfo(parseContext);
|
||||||
|
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
|
||||||
|
chatQueryParser.parse(parseContext, parseResp);
|
||||||
|
}
|
||||||
|
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||||
|
processor.process(parseContext, parseResp);
|
||||||
|
}
|
||||||
|
chatParseReq.setQueryText(parseContext.getQueryText());
|
||||||
|
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||||
|
chatManageService.updateParseCostTime(parseResp);
|
||||||
|
return parseResp;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
||||||
|
QueryResult queryResult = new QueryResult();
|
||||||
|
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
||||||
|
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
||||||
|
queryResult = chatQueryExecutor.execute(executeContext);
|
||||||
|
if (queryResult != null) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (queryResult != null) {
|
||||||
|
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||||
|
processor.process(executeContext, queryResult);
|
||||||
|
}
|
||||||
|
saveQueryResult(chatExecuteReq, queryResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
return queryResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public QueryResult parseAndExecute(int chatId, int agentId, String queryText) {
|
||||||
|
ChatParseReq chatParseReq = new ChatParseReq();
|
||||||
|
chatParseReq.setQueryText(queryText);
|
||||||
|
chatParseReq.setChatId(chatId);
|
||||||
|
chatParseReq.setAgentId(agentId);
|
||||||
|
chatParseReq.setUser(User.getFakeUser());
|
||||||
|
ParseResp parseResp = performParsing(chatParseReq);
|
||||||
|
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||||
|
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
||||||
|
chatId, agentId, queryText);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
ChatExecuteReq executeReq = new ChatExecuteReq();
|
||||||
|
executeReq.setQueryId(parseResp.getQueryId());
|
||||||
|
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||||
|
executeReq.setQueryText(queryText);
|
||||||
|
executeReq.setChatId(chatId);
|
||||||
|
executeReq.setUser(User.getFakeUser());
|
||||||
|
executeReq.setAgentId(agentId);
|
||||||
|
executeReq.setSaveAnswer(true);
|
||||||
|
return performExecution(executeReq);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||||
|
ParseContext parseContext = new ParseContext();
|
||||||
|
BeanMapper.mapper(chatParseReq, parseContext);
|
||||||
|
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||||
|
parseContext.setAgent(agent);
|
||||||
|
return parseContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void supplyMapInfo(ParseContext parseContext) {
|
||||||
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||||
|
MapResp mapResp = chatLayerService.performMapping(queryNLReq);
|
||||||
|
parseContext.setMapInfo(mapResp.getMapInfo());
|
||||||
|
}
|
||||||
|
|
||||||
|
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||||
|
ExecuteContext executeContext = new ExecuteContext();
|
||||||
|
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||||
|
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||||
|
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||||
|
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
||||||
|
executeContext.setAgent(agent);
|
||||||
|
executeContext.setParseInfo(parseInfo);
|
||||||
|
return executeContext;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
|
||||||
|
Integer parseId = chatQueryDataReq.getParseId();
|
||||||
|
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
||||||
|
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
|
||||||
|
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||||
|
|
||||||
|
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||||
|
semanticQuery.setParseInfo(parseInfo);
|
||||||
|
|
||||||
|
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||||
|
handleLLMQueryMode(chatQueryDataReq, semanticQuery, user);
|
||||||
|
} else {
|
||||||
|
handleRuleQueryMode(semanticQuery, dataSetSchema, user);
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeQuery(semanticQuery, user, dataSetSchema);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> getFieldsFromSql(SemanticParseInfo parseInfo) {
|
||||||
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
|
if (Objects.isNull(sqlInfo) || StringUtils.isNotBlank(sqlInfo.getCorrectedS2SQL())) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL());
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq,
|
||||||
|
SemanticQuery semanticQuery,
|
||||||
|
User user) throws Exception {
|
||||||
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
|
List<String> fields = getFieldsFromSql(parseInfo);
|
||||||
|
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
|
||||||
|
log.info("llm begin replace metrics!");
|
||||||
|
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
|
||||||
|
replaceMetrics(parseInfo, metricToReplace);
|
||||||
|
} else {
|
||||||
|
log.info("llm begin revise filters!");
|
||||||
|
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo);
|
||||||
|
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||||
|
semanticQuery.setParseInfo(parseInfo);
|
||||||
|
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||||
|
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||||
|
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleRuleQueryMode(SemanticQuery semanticQuery,
|
||||||
|
DataSetSchema dataSetSchema,
|
||||||
|
User user) {
|
||||||
|
log.info("rule begin replace metrics and revise filters!");
|
||||||
|
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
|
||||||
|
validFilter(semanticQuery.getParseInfo().getMetricFilters());
|
||||||
|
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||||
|
}
|
||||||
|
|
||||||
|
private QueryResult executeQuery(SemanticQuery semanticQuery,
|
||||||
|
User user,
|
||||||
|
DataSetSchema dataSetSchema) throws Exception {
|
||||||
|
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||||
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
|
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
|
||||||
|
queryResult.setChatContext(semanticQuery.getParseInfo());
|
||||||
|
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||||
|
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
|
||||||
|
queryResult.setEntityInfo(entityInfo);
|
||||||
|
return queryResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||||
|
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||||
|
return !oriFields.containsAll(metricNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) {
|
||||||
|
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||||
|
log.info("correctorSql before replacing:{}", correctorSql);
|
||||||
|
// get where filter and having filter
|
||||||
|
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||||
|
|
||||||
|
// replace where filter
|
||||||
|
List<Expression> addWhereConditions = new ArrayList<>();
|
||||||
|
Set<String> removeWhereFieldNames = updateFilters(whereExpressionList, queryData.getDimensionFilters(),
|
||||||
|
parseInfo.getDimensionFilters(), addWhereConditions);
|
||||||
|
|
||||||
|
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||||
|
Set<String> removeDataFieldNames = updateDateInfo(queryData, parseInfo, filedNameToValueMap,
|
||||||
|
whereExpressionList, addWhereConditions);
|
||||||
|
removeWhereFieldNames.addAll(removeDataFieldNames);
|
||||||
|
|
||||||
|
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||||
|
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||||
|
|
||||||
|
// replace having filter
|
||||||
|
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||||
|
List<Expression> addHavingConditions = new ArrayList<>();
|
||||||
|
Set<String> removeHavingFieldNames = updateFilters(havingExpressionList,
|
||||||
|
queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions);
|
||||||
|
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>());
|
||||||
|
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||||
|
|
||||||
|
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
|
||||||
|
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||||
|
log.info("correctorSql after replacing:{}", correctorSql);
|
||||||
|
return correctorSql;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||||
|
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||||
|
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||||
|
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||||
|
log.info("before replaceMetrics:{}", correctorSql);
|
||||||
|
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||||
|
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||||
|
if (!CollectionUtils.isEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||||
|
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||||
|
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||||
|
}
|
||||||
|
log.info("after replaceMetrics:{}", correctorSql);
|
||||||
|
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) throws Exception {
|
||||||
|
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
|
||||||
|
QueryResult queryResult = new QueryResult();
|
||||||
|
|
||||||
|
if (queryResp != null) {
|
||||||
|
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||||
|
queryResult.setQuerySql(queryResp.getSql());
|
||||||
|
queryResult.setQueryResults(queryResp.getResultList());
|
||||||
|
queryResult.setQueryColumns(queryResp.getColumns());
|
||||||
|
} else {
|
||||||
|
queryResult.setQueryResults(new ArrayList<>());
|
||||||
|
queryResult.setQueryColumns(new ArrayList<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
queryResult.setQueryMode(queryMode);
|
||||||
|
queryResult.setQueryState(QueryState.SUCCESS);
|
||||||
|
return queryResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Set<String> updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||||
|
Map<String, Map<String, String>> filedNameToValueMap,
|
||||||
|
List<FieldExpression> fieldExpressionList,
|
||||||
|
List<Expression> addConditions) {
|
||||||
|
Set<String> removeFieldNames = new HashSet<>();
|
||||||
|
if (Objects.isNull(queryData.getDateInfo())) {
|
||||||
|
return removeFieldNames;
|
||||||
|
}
|
||||||
|
if (queryData.getDateInfo().getUnit() > 1) {
|
||||||
|
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
|
||||||
|
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
|
||||||
|
}
|
||||||
|
// startDate equals to endDate
|
||||||
|
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||||
|
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||||
|
// first remove,then add
|
||||||
|
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||||
|
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||||
|
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||||
|
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||||
|
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
|
||||||
|
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
|
||||||
|
&& FilterOperatorEnum.LIKE.getValue().equalsIgnoreCase(
|
||||||
|
fieldExpression.getOperator())) {
|
||||||
|
Map<String, String> replaceMap = new HashMap<>();
|
||||||
|
String preValue = fieldExpression.getFieldValue().toString();
|
||||||
|
String curValue = queryFilter.getValue().toString();
|
||||||
|
if (preValue.startsWith("%")) {
|
||||||
|
curValue = "%" + curValue;
|
||||||
|
}
|
||||||
|
if (preValue.endsWith("%")) {
|
||||||
|
curValue = curValue + "%";
|
||||||
|
}
|
||||||
|
replaceMap.put(preValue, curValue);
|
||||||
|
filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||||
|
return removeFieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
||||||
|
T comparisonExpression,
|
||||||
|
List<Expression> addConditions) {
|
||||||
|
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
||||||
|
StringValue stringValue = new StringValue(date);
|
||||||
|
comparisonExpression.setLeftExpression(column);
|
||||||
|
comparisonExpression.setRightExpression(stringValue);
|
||||||
|
addConditions.add(comparisonExpression);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Set<String> updateFilters(List<FieldExpression> fieldExpressionList,
|
||||||
|
Set<QueryFilter> metricFilters,
|
||||||
|
Set<QueryFilter> contextMetricFilters,
|
||||||
|
List<Expression> addConditions) {
|
||||||
|
Set<String> removeFieldNames = new HashSet<>();
|
||||||
|
if (CollectionUtils.isEmpty(metricFilters)) {
|
||||||
|
return removeFieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (QueryFilter dslQueryFilter : metricFilters) {
|
||||||
|
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||||
|
if (fieldExpression.getFieldName() != null
|
||||||
|
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||||
|
removeFieldNames.add(dslQueryFilter.getName());
|
||||||
|
handleFilter(dslQueryFilter, contextMetricFilters, addConditions);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return removeFieldNames;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void handleFilter(QueryFilter dslQueryFilter,
|
||||||
|
Set<QueryFilter> contextMetricFilters,
|
||||||
|
List<Expression> addConditions) {
|
||||||
|
FilterOperatorEnum operator = dslQueryFilter.getOperator();
|
||||||
|
|
||||||
|
if (operator == FilterOperatorEnum.IN) {
|
||||||
|
addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, addConditions);
|
||||||
|
} else {
|
||||||
|
ComparisonOperator expression = FilterOperatorEnum.createExpression(operator);
|
||||||
|
if (Objects.nonNull(expression)) {
|
||||||
|
addWhereFilters(dslQueryFilter, expression, contextMetricFilters, addConditions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add in condition to sql where condition
|
||||||
|
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
||||||
|
InExpression inExpression,
|
||||||
|
Set<QueryFilter> contextMetricFilters,
|
||||||
|
List<Expression> addConditions) {
|
||||||
|
Column column = new Column(dslQueryFilter.getName());
|
||||||
|
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
|
||||||
|
List<String> valueList = JsonUtil.toList(
|
||||||
|
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
|
||||||
|
if (CollectionUtils.isEmpty(valueList)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
valueList.stream().forEach(o -> {
|
||||||
|
StringValue stringValue = new StringValue(o);
|
||||||
|
parenthesedExpressionList.add(stringValue);
|
||||||
|
});
|
||||||
|
inExpression.setLeftExpression(column);
|
||||||
|
inExpression.setRightExpression(parenthesedExpressionList);
|
||||||
|
addConditions.add(inExpression);
|
||||||
|
contextMetricFilters.stream().forEach(o -> {
|
||||||
|
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||||
|
o.setValue(dslQueryFilter.getValue());
|
||||||
|
o.setOperator(dslQueryFilter.getOperator());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// add where filter
|
||||||
|
private void addWhereFilters(QueryFilter dslQueryFilter,
|
||||||
|
ComparisonOperator comparisonExpression,
|
||||||
|
Set<QueryFilter> contextMetricFilters,
|
||||||
|
List<Expression> addConditions) {
|
||||||
|
String columnName = dslQueryFilter.getName();
|
||||||
|
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
||||||
|
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
||||||
|
}
|
||||||
|
if (Objects.isNull(dslQueryFilter.getValue())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Column column = new Column(columnName);
|
||||||
|
comparisonExpression.setLeftExpression(column);
|
||||||
|
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
|
||||||
|
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
|
||||||
|
comparisonExpression.setRightExpression(longValue);
|
||||||
|
} else {
|
||||||
|
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
|
||||||
|
comparisonExpression.setRightExpression(stringValue);
|
||||||
|
}
|
||||||
|
addConditions.add(comparisonExpression);
|
||||||
|
contextMetricFilters.stream().forEach(o -> {
|
||||||
|
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||||
|
o.setValue(dslQueryFilter.getValue());
|
||||||
|
o.setOperator(dslQueryFilter.getOperator());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
|
||||||
|
ChatQueryDataReq queryData) {
|
||||||
|
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||||
|
return parseInfo;
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
|
||||||
|
parseInfo.setDimensions(queryData.getDimensions());
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(queryData.getMetrics())) {
|
||||||
|
parseInfo.setMetrics(queryData.getMetrics());
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(queryData.getDimensionFilters())) {
|
||||||
|
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
|
||||||
|
parseInfo.setMetricFilters(queryData.getMetricFilters());
|
||||||
|
}
|
||||||
|
if (Objects.nonNull(queryData.getDateInfo())) {
|
||||||
|
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||||
|
}
|
||||||
|
return parseInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void validFilter(Set<QueryFilter> filters) {
|
||||||
|
Iterator<QueryFilter> iterator = filters.iterator();
|
||||||
|
while (iterator.hasNext()) {
|
||||||
|
QueryFilter queryFilter = iterator.next();
|
||||||
|
Object queryFilterValue = queryFilter.getValue();
|
||||||
|
if (Objects.isNull(queryFilterValue)) {
|
||||||
|
iterator.remove();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
List<String> collection = JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
|
||||||
|
if (FilterOperatorEnum.IN.equals(queryFilter.getOperator())
|
||||||
|
&& CollectionUtils.isEmpty(collection)) {
|
||||||
|
iterator.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) {
|
||||||
|
Integer agentId = dimensionValueReq.getAgentId();
|
||||||
|
Agent agent = agentService.getAgent(agentId);
|
||||||
|
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
||||||
|
return semanticLayerService.queryDimensionValue(dimensionValueReq, user);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||||
|
//The history record only retains the query result of the first parse
|
||||||
|
if (chatExecuteReq.getParseId() > 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
chatManageService.saveQueryResult(chatExecuteReq, queryResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,186 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.service.impl;
|
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
|
||||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
|
||||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
|
||||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
|
||||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
|
||||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
|
||||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
|
||||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
|
||||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
|
||||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
@Service
|
|
||||||
public class ChatServiceImpl implements ChatService {
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private ChatManageService chatManageService;
|
|
||||||
@Autowired
|
|
||||||
private ChatQueryService chatQueryService;
|
|
||||||
@Autowired
|
|
||||||
private RetrieveService retrieveService;
|
|
||||||
@Autowired
|
|
||||||
private AgentService agentService;
|
|
||||||
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
|
|
||||||
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
|
|
||||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
|
||||||
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
|
||||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
|
||||||
Agent agent = chatParseContext.getAgent();
|
|
||||||
if (!agent.enableSearch()) {
|
|
||||||
return Lists.newArrayList();
|
|
||||||
}
|
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
|
||||||
return retrieveService.retrieve(queryReq);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
|
||||||
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
|
|
||||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
|
||||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
|
||||||
supplyMapInfo(chatParseContext);
|
|
||||||
for (ChatParser chatParser : chatParsers) {
|
|
||||||
chatParser.parse(chatParseContext, parseResp);
|
|
||||||
}
|
|
||||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
|
||||||
processor.process(chatParseContext, parseResp);
|
|
||||||
}
|
|
||||||
chatParseReq.setQueryText(chatParseContext.getQueryText());
|
|
||||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
|
||||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
|
||||||
chatManageService.updateParseCostTime(parseResp);
|
|
||||||
return parseResp;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
|
||||||
QueryResult queryResult = new QueryResult();
|
|
||||||
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
|
|
||||||
for (ChatExecutor chatExecutor : chatExecutors) {
|
|
||||||
queryResult = chatExecutor.execute(chatExecuteContext);
|
|
||||||
if (queryResult != null) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (queryResult != null) {
|
|
||||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
|
||||||
processor.process(chatExecuteContext, queryResult);
|
|
||||||
}
|
|
||||||
saveQueryResult(chatExecuteReq, queryResult);
|
|
||||||
}
|
|
||||||
|
|
||||||
return queryResult;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public QueryResult parseAndExecute(int chatId, int agentId, String queryText) {
|
|
||||||
ChatParseReq chatParseReq = new ChatParseReq();
|
|
||||||
chatParseReq.setQueryText(queryText);
|
|
||||||
chatParseReq.setChatId(chatId);
|
|
||||||
chatParseReq.setAgentId(agentId);
|
|
||||||
chatParseReq.setUser(User.getFakeUser());
|
|
||||||
ParseResp parseResp = performParsing(chatParseReq);
|
|
||||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
|
||||||
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
|
||||||
chatId, agentId, queryText);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
ChatExecuteReq executeReq = new ChatExecuteReq();
|
|
||||||
executeReq.setQueryId(parseResp.getQueryId());
|
|
||||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
|
||||||
executeReq.setQueryText(queryText);
|
|
||||||
executeReq.setChatId(parseResp.getChatId());
|
|
||||||
executeReq.setUser(User.getFakeUser());
|
|
||||||
executeReq.setAgentId(agentId);
|
|
||||||
executeReq.setSaveAnswer(true);
|
|
||||||
return performExecution(executeReq);
|
|
||||||
}
|
|
||||||
|
|
||||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
|
||||||
ChatParseContext chatParseContext = new ChatParseContext();
|
|
||||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
|
||||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
|
||||||
chatParseContext.setAgent(agent);
|
|
||||||
return chatParseContext;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void supplyMapInfo(ChatParseContext chatParseContext) {
|
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
|
||||||
MapResp mapResp = chatQueryService.performMapping(queryReq);
|
|
||||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
|
||||||
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
|
|
||||||
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
|
|
||||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
|
||||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
|
||||||
chatExecuteContext.setParseInfo(parseInfo);
|
|
||||||
return chatExecuteContext;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
|
|
||||||
Integer parseId = chatQueryDataReq.getParseId();
|
|
||||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
|
||||||
chatQueryDataReq.getQueryId(), parseId);
|
|
||||||
QueryDataReq queryData = new QueryDataReq();
|
|
||||||
BeanMapper.mapper(chatQueryDataReq, queryData);
|
|
||||||
queryData.setParseInfo(parseInfo);
|
|
||||||
return chatQueryService.executeDirectQuery(queryData, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public SemanticParseInfo queryContext(Integer chatId) {
|
|
||||||
return chatQueryService.queryContext(chatId);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
|
||||||
Integer agentId = dimensionValueReq.getAgentId();
|
|
||||||
Agent agent = agentService.getAgent(agentId);
|
|
||||||
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
|
||||||
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
|
||||||
//The history record only retains the query result of the first parse
|
|
||||||
if (chatExecuteReq.getParseId() > 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
chatManageService.saveQueryResult(chatExecuteReq, queryResult);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -30,9 +30,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
|
||||||
import com.tencent.supersonic.headless.server.web.service.DimensionService;
|
|
||||||
import com.tencent.supersonic.headless.server.web.service.MetricService;
|
|
||||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
@@ -53,18 +51,13 @@ public class ConfigServiceImpl implements ConfigService {
|
|||||||
|
|
||||||
private final ChatConfigRepository chatConfigRepository;
|
private final ChatConfigRepository chatConfigRepository;
|
||||||
private final ChatConfigHelper chatConfigHelper;
|
private final ChatConfigHelper chatConfigHelper;
|
||||||
private final DimensionService dimensionService;
|
|
||||||
private final MetricService metricService;
|
|
||||||
private final SemanticLayerService semanticLayerService;
|
private final SemanticLayerService semanticLayerService;
|
||||||
|
|
||||||
|
|
||||||
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
|
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
|
||||||
ChatConfigHelper chatConfigHelper, DimensionService dimensionService,
|
ChatConfigHelper chatConfigHelper, SemanticLayerService semanticLayerService) {
|
||||||
MetricService metricService, SemanticLayerService semanticLayerService) {
|
|
||||||
this.chatConfigRepository = chatConfigRepository;
|
this.chatConfigRepository = chatConfigRepository;
|
||||||
this.chatConfigHelper = chatConfigHelper;
|
this.chatConfigHelper = chatConfigHelper;
|
||||||
this.dimensionService = dimensionService;
|
|
||||||
this.metricService = metricService;
|
|
||||||
this.semanticLayerService = semanticLayerService;
|
this.semanticLayerService = semanticLayerService;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,14 +129,14 @@ public class ConfigServiceImpl implements ConfigService {
|
|||||||
MetaFilter metaFilter = new MetaFilter();
|
MetaFilter metaFilter = new MetaFilter();
|
||||||
metaFilter.setModelIds(Lists.newArrayList(modelId));
|
metaFilter.setModelIds(Lists.newArrayList(modelId));
|
||||||
if (!CollectionUtils.isEmpty(blackDimIdList)) {
|
if (!CollectionUtils.isEmpty(blackDimIdList)) {
|
||||||
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(metaFilter);
|
List<DimensionResp> dimensionRespList = semanticLayerService.getDimensions(metaFilter);
|
||||||
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
|
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
|
||||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||||
itemNameVisibility.setBlackDimNameList(blackDimNameList);
|
itemNameVisibility.setBlackDimNameList(blackDimNameList);
|
||||||
}
|
}
|
||||||
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
|
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
|
||||||
|
|
||||||
List<MetricResp> metricRespList = metricService.getMetrics(metaFilter);
|
List<MetricResp> metricRespList = semanticLayerService.getMetrics(metaFilter);
|
||||||
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
|
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
|
||||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||||
itemNameVisibility.setBlackMetricNameList(blackMetricList);
|
itemNameVisibility.setBlackMetricNameList(blackMetricList);
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -96,19 +96,25 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
return chatMemoryRepository.getMemories(queryWrapper);
|
return chatMemoryRepository.getMemories(queryWrapper);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void enableMemory(ChatMemoryDO memory) {
|
@Override
|
||||||
|
public void enableMemory(ChatMemoryDO memory) {
|
||||||
|
memory.setStatus(MemoryStatus.ENABLED);
|
||||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
SqlExemplar.builder()
|
Text2SQLExemplar.builder()
|
||||||
.question(memory.getQuestion())
|
.question(memory.getQuestion())
|
||||||
|
.sideInfo(memory.getSideInfo())
|
||||||
.dbSchema(memory.getDbSchema())
|
.dbSchema(memory.getDbSchema())
|
||||||
.sql(memory.getS2sql())
|
.sql(memory.getS2sql())
|
||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void disableMemory(ChatMemoryDO memory) {
|
@Override
|
||||||
|
public void disableMemory(ChatMemoryDO memory) {
|
||||||
|
memory.setStatus(MemoryStatus.DISABLED);
|
||||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
SqlExemplar.builder()
|
Text2SQLExemplar.builder()
|
||||||
.question(memory.getQuestion())
|
.question(memory.getQuestion())
|
||||||
|
.sideInfo(memory.getSideInfo())
|
||||||
.dbSchema(memory.getDbSchema())
|
.dbSchema(memory.getDbSchema())
|
||||||
.sql(memory.getS2sql())
|
.sql(memory.getS2sql())
|
||||||
.build());
|
.build());
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.server.service.impl;
|
package com.tencent.supersonic.chat.server.service.impl;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||||
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
|
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.scheduling.annotation.Async;
|
import org.springframework.scheduling.annotation.Async;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||||
@@ -16,8 +16,8 @@ import java.util.List;
|
|||||||
public class ComponentFactory {
|
public class ComponentFactory {
|
||||||
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
||||||
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
||||||
private static List<ChatParser> chatParsers = new ArrayList<>();
|
private static List<ChatQueryParser> chatQueryParsers = new ArrayList<>();
|
||||||
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
|
private static List<ChatQueryExecutor> chatQueryExecutors = new ArrayList<>();
|
||||||
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
||||||
|
|
||||||
public static List<ParseResultProcessor> getParseProcessors() {
|
public static List<ParseResultProcessor> getParseProcessors() {
|
||||||
@@ -30,14 +30,14 @@ public class ComponentFactory {
|
|||||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<ChatParser> getChatParsers() {
|
public static List<ChatQueryParser> getChatParsers() {
|
||||||
return CollectionUtils.isEmpty(chatParsers)
|
return CollectionUtils.isEmpty(chatQueryParsers)
|
||||||
? init(ChatParser.class, chatParsers) : chatParsers;
|
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<ChatExecutor> getChatExecutors() {
|
public static List<ChatQueryExecutor> getChatExecutors() {
|
||||||
return CollectionUtils.isEmpty(chatExecutors)
|
return CollectionUtils.isEmpty(chatQueryExecutors)
|
||||||
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
|
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<PluginRecognizer> getPluginRecognizers() {
|
public static List<PluginRecognizer> getPluginRecognizers() {
|
||||||
|
|||||||
@@ -1,20 +1,20 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LLMConnHelper {
|
public class LLMConnHelper {
|
||||||
public static boolean testConnection(LLMConfig llmConfig) {
|
public static boolean testConnection(ChatModelConfig modelConfig) {
|
||||||
try {
|
try {
|
||||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||||
String response = chatLanguageModel.generate("Hi there");
|
String response = chatLanguageModel.generate("Hi there");
|
||||||
return StringUtils.isNotEmpty(response) ? true : false;
|
return StringUtils.isNotEmpty(response) ? true : false;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
|||||||
@@ -1,37 +1,53 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public class QueryReqConverter {
|
public class QueryReqConverter {
|
||||||
|
|
||||||
public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
|
||||||
QueryReq queryReq = new QueryReq();
|
return buildText2SqlQueryReq(parseContext, null);
|
||||||
BeanMapper.mapper(chatParseContext, queryReq);
|
}
|
||||||
Agent agent = chatParseContext.getAgent();
|
|
||||||
|
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||||
|
QueryNLReq queryNLReq = new QueryNLReq();
|
||||||
|
BeanMapper.mapper(parseContext, queryNLReq);
|
||||||
|
Agent agent = parseContext.getAgent();
|
||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
return queryReq;
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
|
|
||||||
queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
boolean hasLLMTool = agent.containsLLMParserTool();
|
||||||
} else if (agent.containsLLMParserTool()) {
|
boolean hasRuleTool = agent.containsRuleTool();
|
||||||
queryReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
|
||||||
} else if (agent.containsRuleTool()) {
|
|
||||||
queryReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
if (hasLLMTool && hasLLMConfig) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||||
|
} else if (hasLLMTool && hasRuleTool) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||||
|
} else if (hasLLMTool) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||||
|
} else if (hasRuleTool) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||||
}
|
}
|
||||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
queryNLReq.setDataSetIds(agent.getDataSetIds());
|
||||||
if (Objects.nonNull(queryReq.getMapInfo())
|
if (Objects.nonNull(queryNLReq.getMapInfo())
|
||||||
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
|
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
|
||||||
queryReq.setMapInfo(queryReq.getMapInfo());
|
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
||||||
}
|
}
|
||||||
queryReq.setLlmConfig(agent.getLlmConfig());
|
queryNLReq.setModelConfig(agent.getModelConfig());
|
||||||
return queryReq;
|
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
||||||
|
if (chatCtx != null) {
|
||||||
|
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||||
|
}
|
||||||
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||||
|
|
||||||
|
|
||||||
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper">
|
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
|
||||||
|
|
||||||
<resultMap id="ChatContextDO"
|
<resultMap id="ChatContextDO"
|
||||||
type="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO">
|
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
|
||||||
<id column="chat_id" property="chatId"/>
|
<id column="chat_id" property="chatId"/>
|
||||||
<result column="modified_at" property="modifiedAt"/>
|
<result column="modified_at" property="modifiedAt"/>
|
||||||
<result column="user" property="user"/>
|
<result column="user" property="user"/>
|
||||||
@@ -20,7 +20,7 @@
|
|||||||
from s2_chat_context where chat_id=#{chatId} limit 1
|
from s2_chat_context where chat_id=#{chatId} limit 1
|
||||||
</select>
|
</select>
|
||||||
|
|
||||||
<insert id="addContext" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO" >
|
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
|
||||||
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
|
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
|
||||||
</insert>
|
</insert>
|
||||||
<update id="updateContext">
|
<update id="updateContext">
|
||||||
@@ -3,9 +3,9 @@
|
|||||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||||
|
|
||||||
|
|
||||||
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper">
|
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
|
||||||
|
|
||||||
<resultMap id="Statistics" type="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
|
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||||
<id column="question_id" property="questionId"/>
|
<id column="question_id" property="questionId"/>
|
||||||
<result column="chat_id" property="chatId"/>
|
<result column="chat_id" property="chatId"/>
|
||||||
<result column="user_name" property="userName"/>
|
<result column="user_name" property="userName"/>
|
||||||
@@ -16,7 +16,7 @@
|
|||||||
<result column="create_time" property="createTime"/>
|
<result column="create_time" property="createTime"/>
|
||||||
</resultMap>
|
</resultMap>
|
||||||
|
|
||||||
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
|
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||||
insert into s2_chat_statistics
|
insert into s2_chat_statistics
|
||||||
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
|
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
|
||||||
values
|
values
|
||||||
@@ -46,11 +46,6 @@
|
|||||||
</exclusion>
|
</exclusion>
|
||||||
</exclusions>
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>commons-lang</groupId>
|
|
||||||
<artifactId>commons-lang</artifactId>
|
|
||||||
<version>${commons.lang.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.alibaba</groupId>
|
<groupId>com.alibaba</groupId>
|
||||||
@@ -67,6 +62,11 @@
|
|||||||
<groupId>org.apache.commons</groupId>
|
<groupId>org.apache.commons</groupId>
|
||||||
<artifactId>commons-lang3</artifactId>
|
<artifactId>commons-lang3</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-compress</artifactId>
|
||||||
|
<version>${commons.compress.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.junit.jupiter</groupId>
|
<groupId>org.junit.jupiter</groupId>
|
||||||
|
|||||||
@@ -0,0 +1,179 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.ImmutableMap;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.langchain4j.provider.AzureModelFactory;
|
||||||
|
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||||
|
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||||
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
|
import dev.langchain4j.provider.QianfanModelFactory;
|
||||||
|
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service("ChatModelParameterConfig")
|
||||||
|
@Slf4j
|
||||||
|
public class ChatModelParameterConfig extends ParameterConfig {
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||||
|
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
||||||
|
"接口协议", "", "list",
|
||||||
|
"对话模型配置", getCandidateValues());
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||||
|
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
"BaseUrl", "", "string",
|
||||||
|
"对话模型配置", null, getBaseUrlDependency());
|
||||||
|
public static final Parameter CHAT_MODEL_ENDPOINT =
|
||||||
|
new Parameter("s2.chat.model.endpoint", "llama_2_70b",
|
||||||
|
"Endpoint", "", "string",
|
||||||
|
"对话模型配置", null, getEndpointDependency());
|
||||||
|
public static final Parameter CHAT_MODEL_API_KEY =
|
||||||
|
new Parameter("s2.chat.model.api.key", DEMO,
|
||||||
|
"ApiKey", "", "password",
|
||||||
|
"对话模型配置", null, getApiKeyDependency()
|
||||||
|
);
|
||||||
|
public static final Parameter CHAT_MODEL_SECRET_KEY =
|
||||||
|
new Parameter("s2.chat.model.secretKey", "demo",
|
||||||
|
"SecretKey", "", "password",
|
||||||
|
"对话模型配置", null, getSecretKeyDependency());
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_NAME =
|
||||||
|
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
|
||||||
|
"ModelName", "", "string",
|
||||||
|
"对话模型配置", null, getModelNameDependency());
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
|
||||||
|
new Parameter("s2.chat.model.enableSearch", "false",
|
||||||
|
"是否启用搜索增强功能,设为false表示不启用", "", "bool",
|
||||||
|
"对话模型配置", null, getEnableSearchDependency());
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||||
|
new Parameter("s2.chat.model.temperature", "0.0",
|
||||||
|
"Temperature", "",
|
||||||
|
"slider", "对话模型配置");
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||||
|
new Parameter("s2.chat.model.timeout", "60",
|
||||||
|
"超时时间(秒)", "",
|
||||||
|
"number", "对话模型配置");
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Parameter> getSysParameters() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||||
|
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||||
|
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChatModelConfig convert() {
|
||||||
|
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
|
||||||
|
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
|
||||||
|
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
|
||||||
|
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||||
|
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||||
|
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||||
|
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
||||||
|
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
||||||
|
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
|
||||||
|
|
||||||
|
return ChatModelConfig.builder()
|
||||||
|
.provider(chatModelProvider)
|
||||||
|
.baseUrl(chatModelBaseUrl)
|
||||||
|
.apiKey(chatModelApiKey)
|
||||||
|
.modelName(chatModelName)
|
||||||
|
.enableSearch(Boolean.valueOf(enableSearch))
|
||||||
|
.temperature(Double.valueOf(chatModelTemperature))
|
||||||
|
.timeOut(Long.valueOf(chatModelTimeout))
|
||||||
|
.endpoint(endpoint)
|
||||||
|
.secretKey(secretKey)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<String> getCandidateValues() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER,
|
||||||
|
LocalAiModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
getCandidateValues(),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||||
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||||
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||||
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||||
|
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER,
|
||||||
|
LocalAiModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER
|
||||||
|
),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, DEMO,
|
||||||
|
QianfanModelFactory.PROVIDER, DEMO,
|
||||||
|
ZhipuModelFactory.PROVIDER, DEMO,
|
||||||
|
LocalAiModelFactory.PROVIDER, DEMO,
|
||||||
|
AzureModelFactory.PROVIDER, DEMO,
|
||||||
|
DashscopeModelFactory.PROVIDER, DEMO
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
getCandidateValues(),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -16,7 +16,10 @@ public class DataBaseConfig {
|
|||||||
@Primary
|
@Primary
|
||||||
@ConfigurationProperties("spring.datasource")
|
@ConfigurationProperties("spring.datasource")
|
||||||
public DataSource dataSource() {
|
public DataSource dataSource() {
|
||||||
return new DruidDataSource();
|
DruidDataSource druidDataSource = new DruidDataSource();
|
||||||
|
druidDataSource.setTestWhileIdle(true);
|
||||||
|
druidDataSource.setValidationQuery("select 1");
|
||||||
|
return druidDataSource;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,168 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.ImmutableMap;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.langchain4j.provider.AzureModelFactory;
|
||||||
|
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||||
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
|
import dev.langchain4j.provider.QianfanModelFactory;
|
||||||
|
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service("EmbeddingModelParameterConfig")
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||||
|
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||||
|
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
||||||
|
"接口协议", "", "list",
|
||||||
|
"向量模型配置", getCandidateValues());
|
||||||
|
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||||
|
new Parameter("s2.embedding.model.base.url", "",
|
||||||
|
"BaseUrl", "", "string",
|
||||||
|
"向量模型配置", null, getBaseUrlDependency()
|
||||||
|
);
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||||
|
new Parameter("s2.embedding.model.api.key", "",
|
||||||
|
"ApiKey", "", "password",
|
||||||
|
"向量模型配置", null, getApiKeyDependency());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
|
||||||
|
new Parameter("s2.embedding.model.secretKey", "demo",
|
||||||
|
"SecretKey", "", "password",
|
||||||
|
"向量模型配置", null, getSecretKeyDependency());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||||
|
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
|
"ModelName", "", "string",
|
||||||
|
"向量模型配置", null, getModelNameDependency());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_MODEL_PATH =
|
||||||
|
new Parameter("s2.embedding.model.path", "",
|
||||||
|
"模型路径", "", "string",
|
||||||
|
"向量模型配置", null, getModelPathDependency());
|
||||||
|
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||||
|
new Parameter("s2.embedding.model.vocabulary.path", "",
|
||||||
|
"词汇表路径", "", "string",
|
||||||
|
"向量模型配置", null, getModelPathDependency());
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Parameter> getSysParameters() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
|
||||||
|
EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH,
|
||||||
|
EMBEDDING_MODEL_VOCABULARY_PATH
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public EmbeddingModelConfig convert() {
|
||||||
|
String provider = getParameterValue(EMBEDDING_MODEL_PROVIDER);
|
||||||
|
String baseUrl = getParameterValue(EMBEDDING_MODEL_BASE_URL);
|
||||||
|
String apiKey = getParameterValue(EMBEDDING_MODEL_API_KEY);
|
||||||
|
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
|
||||||
|
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
|
||||||
|
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
|
||||||
|
String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
|
||||||
|
return EmbeddingModelConfig.builder()
|
||||||
|
.provider(provider)
|
||||||
|
.baseUrl(baseUrl)
|
||||||
|
.apiKey(apiKey)
|
||||||
|
.secretKey(secretKey)
|
||||||
|
.modelName(modelName)
|
||||||
|
.modelPath(modelPath)
|
||||||
|
.vocabularyPath(vocabularyPath)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ArrayList<String> getCandidateValues() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
InMemoryModelFactory.PROVIDER,
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(
|
||||||
|
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||||
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||||
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||||
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||||
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO,
|
||||||
|
AzureModelFactory.PROVIDER, DEMO,
|
||||||
|
DashscopeModelFactory.PROVIDER, DEMO,
|
||||||
|
QianfanModelFactory.PROVIDER, DEMO,
|
||||||
|
ZhipuModelFactory.PROVIDER, DEMO)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
InMemoryModelFactory.PROVIDER,
|
||||||
|
OpenAiModelFactory.PROVIDER,
|
||||||
|
OllamaModelFactory.PROVIDER,
|
||||||
|
AzureModelFactory.PROVIDER,
|
||||||
|
DashscopeModelFactory.PROVIDER,
|
||||||
|
QianfanModelFactory.PROVIDER,
|
||||||
|
ZhipuModelFactory.PROVIDER
|
||||||
|
),
|
||||||
|
ImmutableMap.of(
|
||||||
|
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
|
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
|
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
|
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
|
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
|
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(InMemoryModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.ImmutableMap;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreType;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Service("EmbeddingStoreParameterConfig")
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||||
|
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||||
|
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
|
||||||
|
"向量库类型", "目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list",
|
||||||
|
"向量库配置", getCandidateValues());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||||
|
new Parameter("s2.embedding.store.base.url", "",
|
||||||
|
"BaseUrl", "", "string",
|
||||||
|
"向量库配置", null, getBaseUrlDependency());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||||
|
new Parameter("s2.embedding.store.api.key", "",
|
||||||
|
"ApiKey", "", "password",
|
||||||
|
"向量库配置", null, getApiKeyDependency());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||||
|
new Parameter("s2.embedding.store.persist.path", "",
|
||||||
|
"持久化路径", "默认不持久化,如需持久化请填写持久化路径。"
|
||||||
|
+ "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
|
||||||
|
"向量库配置", null, getPathDependency());
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||||
|
new Parameter("s2.embedding.store.timeout", "60",
|
||||||
|
"超时时间(秒)", "",
|
||||||
|
"number", "向量库配置");
|
||||||
|
|
||||||
|
public static final Parameter EMBEDDING_STORE_DIMENSION =
|
||||||
|
new Parameter("s2.embedding.store.dimension", "",
|
||||||
|
"纬度", "", "number",
|
||||||
|
"向量库配置", null, getDimensionDependency());
|
||||||
|
public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
|
||||||
|
new Parameter("s2.embedding.store.databaseName", "",
|
||||||
|
"DatabaseName", "", "string",
|
||||||
|
"向量库配置", null, getDatabaseNameDependency());
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<Parameter> getSysParameters() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_API_KEY,
|
||||||
|
EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH,
|
||||||
|
EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public EmbeddingStoreConfig convert() {
|
||||||
|
String provider = getParameterValue(EMBEDDING_STORE_PROVIDER);
|
||||||
|
String baseUrl = getParameterValue(EMBEDDING_STORE_BASE_URL);
|
||||||
|
String apiKey = getParameterValue(EMBEDDING_STORE_API_KEY);
|
||||||
|
String persistPath = getParameterValue(EMBEDDING_STORE_PERSIST_PATH);
|
||||||
|
String timeOut = getParameterValue(EMBEDDING_STORE_TIMEOUT);
|
||||||
|
String databaseName = getParameterValue(EMBEDDING_STORE_DATABASE_NAME);
|
||||||
|
Integer dimension = null;
|
||||||
|
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
|
||||||
|
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
|
||||||
|
}
|
||||||
|
return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl)
|
||||||
|
.apiKey(apiKey).persistPath(persistPath).databaseName(databaseName)
|
||||||
|
.timeOut(Long.valueOf(timeOut)).dimension(dimension).build();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static ArrayList<String> getCandidateValues() {
|
||||||
|
return Lists.newArrayList(
|
||||||
|
EmbeddingStoreType.IN_MEMORY.name(),
|
||||||
|
EmbeddingStoreType.MILVUS.name(),
|
||||||
|
EmbeddingStoreType.CHROMA.name());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(
|
||||||
|
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()),
|
||||||
|
ImmutableMap.of(
|
||||||
|
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||||
|
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||||
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getPathDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
|
||||||
|
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), ""));
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||||
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
||||||
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||||
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package com.tencent.supersonic.common.config;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class LLMConfig {
|
|
||||||
|
|
||||||
private String provider;
|
|
||||||
|
|
||||||
private String baseUrl;
|
|
||||||
|
|
||||||
private String apiKey;
|
|
||||||
|
|
||||||
private String modelName;
|
|
||||||
|
|
||||||
private Double temperature = 0.0d;
|
|
||||||
|
|
||||||
private Long timeOut = 60L;
|
|
||||||
|
|
||||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
|
|
||||||
this.provider = provider;
|
|
||||||
this.baseUrl = baseUrl;
|
|
||||||
this.apiKey = apiKey;
|
|
||||||
this.modelName = modelName;
|
|
||||||
}
|
|
||||||
|
|
||||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
|
|
||||||
double temperature) {
|
|
||||||
this.provider = provider;
|
|
||||||
this.baseUrl = baseUrl;
|
|
||||||
this.apiKey = apiKey;
|
|
||||||
this.modelName = modelName;
|
|
||||||
this.temperature = temperature;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String keyDecrypt() {
|
|
||||||
return AESEncryptionUtil.aesDecryptECB(apiKey);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -7,11 +7,14 @@ import org.springframework.beans.factory.annotation.Autowired;
|
|||||||
import org.springframework.core.env.Environment;
|
import org.springframework.core.env.Environment;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public abstract class ParameterConfig {
|
public abstract class ParameterConfig {
|
||||||
|
public static final String DEMO = "demo";
|
||||||
@Autowired
|
@Autowired
|
||||||
private SystemConfigService sysConfigService;
|
private SystemConfigService sysConfigService;
|
||||||
|
|
||||||
@@ -21,13 +24,16 @@ public abstract class ParameterConfig {
|
|||||||
/**
|
/**
|
||||||
* @return system parameters to be set with user interface
|
* @return system parameters to be set with user interface
|
||||||
*/
|
*/
|
||||||
protected abstract List<Parameter> getSysParameters();
|
protected List<Parameter> getSysParameters() {
|
||||||
|
return Collections.EMPTY_LIST;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Parameter value will be derived in the following order:
|
* Parameter value will be derived in the following order:
|
||||||
* 1. `system config` set with user interface
|
* 1. `system config` set with user interface
|
||||||
* 2. `system property` set with application.yaml file
|
* 2. `system property` set with application.yaml file
|
||||||
* 3. `default value` set with parameter declaration
|
* 3. `default value` set with parameter declaration
|
||||||
|
*
|
||||||
* @param parameter instance
|
* @param parameter instance
|
||||||
* @return parameter value
|
* @return parameter value
|
||||||
*/
|
*/
|
||||||
@@ -44,4 +50,22 @@ public abstract class ParameterConfig {
|
|||||||
|
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected static List<Parameter.Dependency> getDependency(
|
||||||
|
String dependencyParameterName,
|
||||||
|
List<String> includesValue,
|
||||||
|
Map<String, String> setDefaultValue) {
|
||||||
|
|
||||||
|
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
|
||||||
|
show.setIncludesValue(includesValue);
|
||||||
|
|
||||||
|
Parameter.Dependency dependency = new Parameter.Dependency();
|
||||||
|
dependency.setName(dependencyParameterName);
|
||||||
|
dependency.setShow(show);
|
||||||
|
dependency.setSetDefaultValue(setDefaultValue);
|
||||||
|
List<Parameter.Dependency> dependencies = new ArrayList<>();
|
||||||
|
dependencies.add(dependency);
|
||||||
|
return dependencies;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class PromptConfig {
|
||||||
|
|
||||||
|
private String promptTemplate;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -29,8 +29,7 @@ public enum AggregateEnum {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static Map<String, String> getAggregateEnum() {
|
public static Map<String, String> getAggregateEnum() {
|
||||||
Map<String, String> aggregateMap = Arrays.stream(AggregateEnum.values())
|
return Arrays.stream(AggregateEnum.values())
|
||||||
.collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
|
.collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
|
||||||
return aggregateMap;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,24 +1,36 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||||
|
import net.sf.jsqlparser.expression.Function;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
|
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||||
private Map<String, String> fieldNameMap;
|
private Map<String, String> fieldNameMap;
|
||||||
private boolean exactReplace;
|
private ThreadLocal<Boolean> exactReplace = ThreadLocal.withInitial(() -> false);
|
||||||
|
|
||||||
public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
|
public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||||
this.fieldNameMap = fieldNameMap;
|
this.fieldNameMap = fieldNameMap;
|
||||||
this.exactReplace = exactReplace;
|
this.exactReplace.set(exactReplace);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(Column column) {
|
public void visit(Column column) {
|
||||||
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace);
|
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visit(Function function) {
|
||||||
|
boolean originalExactReplace = exactReplace.get();
|
||||||
|
exactReplace.set(true);
|
||||||
|
try {
|
||||||
|
super.visit(function);
|
||||||
|
} finally {
|
||||||
|
exactReplace.set(originalExactReplace);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,10 +1,5 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.DoubleValue;
|
import net.sf.jsqlparser.expression.DoubleValue;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -24,14 +19,19 @@ import net.sf.jsqlparser.schema.Column;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
|
public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||||
private boolean exactReplace;
|
private boolean exactReplace;
|
||||||
private Map<String, Map<String, String>> filedNameToValueMap;
|
private Map<String, Map<String, String>> filedNameToValueMap;
|
||||||
|
|
||||||
public FieldlValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
|
public FieldValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||||
this.exactReplace = exactReplace;
|
this.exactReplace = exactReplace;
|
||||||
this.filedNameToValueMap = filedNameToValueMap;
|
this.filedNameToValueMap = filedNameToValueMap;
|
||||||
}
|
}
|
||||||
@@ -71,17 +71,13 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
values.add(((StringValue) o).getValue());
|
values.add(((StringValue) o).getValue());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
if (valueMap == null) {
|
if (valueMap == null || CollectionUtils.isEmpty(values)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String value = valueMap.get(JsonUtil.toString(values));
|
|
||||||
if (StringUtils.isBlank(value)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<String> valueList = JsonUtil.toList(value, String.class);
|
|
||||||
List<Expression> newExpressions = new ArrayList<>();
|
List<Expression> newExpressions = new ArrayList<>();
|
||||||
valueList.stream().forEach(o -> {
|
values.stream().forEach(o -> {
|
||||||
StringValue stringValue = new StringValue(o);
|
String replaceValue = valueMap.getOrDefault(o, o);
|
||||||
|
StringValue stringValue = new StringValue(replaceValue);
|
||||||
newExpressions.add(stringValue);
|
newExpressions.add(stringValue);
|
||||||
});
|
});
|
||||||
rightItemsList.setExpressions(newExpressions);
|
rightItemsList.setExpressions(newExpressions);
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -20,6 +16,11 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
|||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
@@ -76,37 +77,39 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
|
|
||||||
public List<Expression> parserFilter(ComparisonOperator comparisonOperator, String condExpr) {
|
public List<Expression> parserFilter(ComparisonOperator comparisonOperator, String condExpr) {
|
||||||
List<Expression> result = new ArrayList<>();
|
List<Expression> result = new ArrayList<>();
|
||||||
String toString = comparisonOperator.toString();
|
String comparisonOperatorStr = comparisonOperator.toString();
|
||||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||||
|
|
||||||
if (!(leftExpression instanceof Function)) {
|
if (!(leftExpression instanceof Function)) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
Function leftExpressionFunction = (Function) leftExpression;
|
|
||||||
if (leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
|
Function leftFunction = (Function) leftExpression;
|
||||||
|
if (leftFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
//List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
ExpressionList<?> leftFunctionParams = leftFunction.getParameters();
|
||||||
ExpressionList<?> leftExpressions = leftExpressionFunction.getParameters();
|
if (CollectionUtils.isEmpty(leftFunctionParams)) {
|
||||||
if (CollectionUtils.isEmpty(leftExpressions)) {
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
Column field = (Column) leftExpressions.get(0);
|
|
||||||
|
Column field = (Column) leftFunctionParams.get(0);
|
||||||
String columnName = field.getColumnName();
|
String columnName = field.getColumnName();
|
||||||
if (!fieldNames.contains(columnName)) {
|
if (!fieldNames.contains(columnName)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
ComparisonOperator parsedExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
||||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
||||||
comparisonOperator.setASTNode(expression.getASTNode());
|
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
||||||
result.add(CCJSqlParserUtil.parseCondExpression(toString));
|
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
|
||||||
return result;
|
return result;
|
||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
log.error("JSQLParserException", e);
|
log.error("JSQLParserException", e);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||||
@@ -12,9 +9,11 @@ import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
|
|||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
public static final String PREFIX = "%";
|
public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
|
||||||
private Map<String, Set<String>> fieldValueToFieldNames;
|
private Map<String, Set<String>> fieldValueToFieldNames;
|
||||||
|
|
||||||
public FiledNameReplaceVisitor(Map<String, Set<String>> fieldValueToFieldNames) {
|
public FiledNameReplaceVisitor(Map<String, Set<String>> fieldValueToFieldNames) {
|
||||||
@@ -34,40 +33,20 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
private void replaceFieldNameByFieldValue(BinaryExpression expr) {
|
private void replaceFieldNameByFieldValue(BinaryExpression expr) {
|
||||||
Expression leftExpression = expr.getLeftExpression();
|
Expression leftExpression = expr.getLeftExpression();
|
||||||
Expression rightExpression = expr.getRightExpression();
|
Expression rightExpression = expr.getRightExpression();
|
||||||
if (!(rightExpression instanceof StringValue)) {
|
|
||||||
|
if (!(rightExpression instanceof StringValue) || !(leftExpression instanceof Column)
|
||||||
|
|| CollectionUtils.isEmpty(fieldValueToFieldNames)
|
||||||
|
|| Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (!(leftExpression instanceof Column)) {
|
|
||||||
return;
|
Column leftColumn = (Column) leftExpression;
|
||||||
}
|
|
||||||
if (CollectionUtils.isEmpty(fieldValueToFieldNames)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Column leftColumnName = (Column) leftExpression;
|
|
||||||
StringValue rightStringValue = (StringValue) rightExpression;
|
StringValue rightStringValue = (StringValue) rightExpression;
|
||||||
|
|
||||||
if (expr instanceof LikeExpression) {
|
|
||||||
String value = getValue(rightStringValue.getValue());
|
|
||||||
rightStringValue.setValue(value);
|
|
||||||
}
|
|
||||||
|
|
||||||
Set<String> fieldNames = fieldValueToFieldNames.get(rightStringValue.getValue());
|
Set<String> fieldNames = fieldValueToFieldNames.get(rightStringValue.getValue());
|
||||||
if (!CollectionUtils.isEmpty(fieldNames) && !fieldNames.contains(leftColumnName.getColumnName())) {
|
if (!CollectionUtils.isEmpty(fieldNames) && !fieldNames.contains(leftColumn.getColumnName())) {
|
||||||
leftColumnName.setColumnName(fieldNames.stream().findFirst().get());
|
leftColumn.setColumnName(fieldNames.stream().findFirst().get());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getValue(String value) {
|
|
||||||
if (value.startsWith(PREFIX)) {
|
|
||||||
value = value.substring(1);
|
|
||||||
}
|
|
||||||
if (value.endsWith(PREFIX)) {
|
|
||||||
value = value.substring(0, value.length() - 1);
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
|
||||||
import net.sf.jsqlparser.expression.LongValue;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
|
||||||
import net.sf.jsqlparser.schema.Column;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
public class FilterRemoveVisitor extends ExpressionVisitorAdapter {
|
|
||||||
|
|
||||||
private List<String> filedNames;
|
|
||||||
|
|
||||||
public FilterRemoveVisitor(List<String> filedNames) {
|
|
||||||
this.filedNames = filedNames;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean isRemove(Expression leftExpression) {
|
|
||||||
if (!(leftExpression instanceof Column)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
Column leftColumnName = (Column) leftExpression;
|
|
||||||
String columnName = leftColumnName.getColumnName();
|
|
||||||
if (StringUtils.isEmpty(columnName)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (!filedNames.contains(columnName)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(EqualsTo expr) {
|
|
||||||
if (!isRemove(expr.getLeftExpression())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
expr.setRightExpression(new LongValue(1L));
|
|
||||||
expr.setLeftExpression(new LongValue(1L));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(MinorThan expr) {
|
|
||||||
if (!isRemove(expr.getLeftExpression())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
expr.setRightExpression(new LongValue(1L));
|
|
||||||
expr.setLeftExpression(new LongValue(0L));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(MinorThanEquals expr) {
|
|
||||||
if (!isRemove(expr.getLeftExpression())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
expr.setRightExpression(new LongValue(1L));
|
|
||||||
expr.setLeftExpression(new LongValue(1L));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(GreaterThan expr) {
|
|
||||||
if (!isRemove(expr.getLeftExpression())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
expr.setRightExpression(new LongValue(0L));
|
|
||||||
expr.setLeftExpression(new LongValue(1L));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(GreaterThanEquals expr) {
|
|
||||||
if (!isRemove(expr.getLeftExpression())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
expr.setRightExpression(new LongValue(1L));
|
|
||||||
expr.setLeftExpression(new LongValue(1L));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(InExpression expr) {
|
|
||||||
if (!isRemove(expr.getLeftExpression())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
expr.setNot(false);
|
|
||||||
expr.setRightExpression(new LongValue(1L));
|
|
||||||
expr.setLeftExpression(new LongValue(1L));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,24 +1,21 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.function.UnaryOperator;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||||
import net.sf.jsqlparser.expression.Function;
|
import net.sf.jsqlparser.expression.Function;
|
||||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FunctionNameReplaceVisitor extends ExpressionVisitorAdapter {
|
public class FunctionNameReplaceVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
private Map<String, String> functionMap;
|
private Map<String, String> functionMap;
|
||||||
private Map<String, UnaryOperator> functionCallMap;
|
private Map<String, UnaryOperator> functionCallMap;
|
||||||
|
|
||||||
public FunctionNameReplaceVisitor(Map<String, String> functionMap) {
|
|
||||||
this.functionMap = functionMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
public FunctionNameReplaceVisitor(Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
|
public FunctionNameReplaceVisitor(Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
|
||||||
this.functionMap = functionMap;
|
this.functionMap = functionMap;
|
||||||
this.functionCallMap = functionCallMap;
|
this.functionCallMap = functionCallMap;
|
||||||
|
|||||||
@@ -1,123 +0,0 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
|
||||||
import net.sf.jsqlparser.expression.Function;
|
|
||||||
import net.sf.jsqlparser.expression.LongValue;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
|
||||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
|
||||||
import net.sf.jsqlparser.schema.Column;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FunctionReplaceVisitor extends ExpressionVisitorAdapter {
|
|
||||||
|
|
||||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(MinorThan expr) {
|
|
||||||
List<Expression> expressions = reparseDate(expr, ">");
|
|
||||||
if (Objects.nonNull(expressions)) {
|
|
||||||
waitingForAdds.addAll(expressions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(EqualsTo expr) {
|
|
||||||
List<Expression> expressions = reparseDate(expr, ">=");
|
|
||||||
if (Objects.nonNull(expressions)) {
|
|
||||||
waitingForAdds.addAll(expressions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(MinorThanEquals expr) {
|
|
||||||
List<Expression> expressions = reparseDate(expr, ">=");
|
|
||||||
if (Objects.nonNull(expressions)) {
|
|
||||||
waitingForAdds.addAll(expressions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(GreaterThan expr) {
|
|
||||||
List<Expression> expressions = reparseDate(expr, "<");
|
|
||||||
if (Objects.nonNull(expressions)) {
|
|
||||||
waitingForAdds.addAll(expressions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visit(GreaterThanEquals expr) {
|
|
||||||
List<Expression> expressions = reparseDate(expr, "<=");
|
|
||||||
if (Objects.nonNull(expressions)) {
|
|
||||||
waitingForAdds.addAll(expressions);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Expression> getWaitingForAdds() {
|
|
||||||
return waitingForAdds;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<Expression> reparseDate(ComparisonOperator comparisonOperator, String startDateOperator) {
|
|
||||||
List<Expression> result = new ArrayList<>();
|
|
||||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
|
||||||
if (!(leftExpression instanceof Function)) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
Function leftExpressionFunction = (Function) leftExpression;
|
|
||||||
if (!leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
//List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
|
||||||
ExpressionList<?> leftExpressions = leftExpressionFunction.getParameters();
|
|
||||||
if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) {
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
Column field = (Column) leftExpressions.get(1);
|
|
||||||
String columnName = field.getColumnName();
|
|
||||||
try {
|
|
||||||
String startDateValue = DateFunctionHelper.getStartDateStr(comparisonOperator, leftExpressions);
|
|
||||||
String endDateValue = DateFunctionHelper.getEndDateValue(leftExpressions);
|
|
||||||
String endDateOperator = comparisonOperator.getStringExpression();
|
|
||||||
String condExpr =
|
|
||||||
columnName + StringUtil.getSpaceWrap(DateFunctionHelper.getEndDateOperator(comparisonOperator))
|
|
||||||
+ StringUtil.getCommaWrap(endDateValue);
|
|
||||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
|
||||||
|
|
||||||
String startDataCondExpr =
|
|
||||||
columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
|
|
||||||
if (JsqlConstants.EQUAL.equalsIgnoreCase(endDateOperator)) {
|
|
||||||
result.add(CCJSqlParserUtil.parseCondExpression(condExpr));
|
|
||||||
expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(JsqlConstants.EQUAL_CONSTANT);
|
|
||||||
}
|
|
||||||
if (startDateOperator.equals("<=") || startDateOperator.equals("<")) {
|
|
||||||
comparisonOperator.setLeftExpression(new Column("1"));
|
|
||||||
comparisonOperator.setRightExpression(new LongValue(1));
|
|
||||||
comparisonOperator.setASTNode(null);
|
|
||||||
} else {
|
|
||||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
|
||||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
|
||||||
comparisonOperator.setASTNode(expression.getASTNode());
|
|
||||||
}
|
|
||||||
result.add(CCJSqlParserUtil.parseCondExpression(startDataCondExpr));
|
|
||||||
return result;
|
|
||||||
} catch (JSQLParserException e) {
|
|
||||||
log.error("JSQLParserException", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.function.UnaryOperator;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.expression.Function;
|
import net.sf.jsqlparser.expression.Function;
|
||||||
@@ -12,16 +8,17 @@ import net.sf.jsqlparser.statement.select.GroupByElement;
|
|||||||
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
|
public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
|
||||||
|
|
||||||
private Map<String, String> functionMap;
|
private Map<String, String> functionMap;
|
||||||
private Map<String, UnaryOperator> functionCallMap;
|
private Map<String, UnaryOperator> functionCallMap;
|
||||||
|
|
||||||
public GroupByFunctionReplaceVisitor(Map<String, String> functionMap) {
|
|
||||||
this.functionMap = functionMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
public GroupByFunctionReplaceVisitor(Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
|
public GroupByFunctionReplaceVisitor(Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
|
||||||
this.functionMap = functionMap;
|
this.functionMap = functionMap;
|
||||||
this.functionCallMap = functionCallMap;
|
this.functionCallMap = functionCallMap;
|
||||||
@@ -31,14 +28,16 @@ public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
|
|||||||
groupByElement.getGroupByExpressionList();
|
groupByElement.getGroupByExpressionList();
|
||||||
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
|
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
|
||||||
List<Expression> groupByExpressions = groupByExpressionList.getExpressions();
|
List<Expression> groupByExpressions = groupByExpressionList.getExpressions();
|
||||||
|
for (Expression expression : groupByExpressions) {
|
||||||
for (int i = 0; i < groupByExpressions.size(); i++) {
|
if (!(expression instanceof Function)) {
|
||||||
Expression expression = groupByExpressions.get(i);
|
continue;
|
||||||
if (expression instanceof Function) {
|
}
|
||||||
Function function = (Function) expression;
|
Function function = (Function) expression;
|
||||||
String functionName = function.getName().toLowerCase();
|
String functionName = function.getName().toLowerCase();
|
||||||
String replaceName = functionMap.get(functionName);
|
String replaceName = functionMap.get(functionName);
|
||||||
if (StringUtils.isNotBlank(replaceName)) {
|
if (StringUtils.isBlank(replaceName)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
function.setName(replaceName);
|
function.setName(replaceName);
|
||||||
if (Objects.nonNull(functionCallMap) && functionCallMap.containsKey(functionName)) {
|
if (Objects.nonNull(functionCallMap) && functionCallMap.containsKey(functionName)) {
|
||||||
Object ret = functionCallMap.get(functionName).apply(function.getParameters());
|
Object ret = functionCallMap.get(functionName).apply(function.getParameters());
|
||||||
@@ -49,6 +48,4 @@ public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,5 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -14,6 +11,10 @@ import net.sf.jsqlparser.statement.select.GroupByElement;
|
|||||||
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class GroupByReplaceVisitor implements GroupByVisitor {
|
public class GroupByReplaceVisitor implements GroupByVisitor {
|
||||||
|
|
||||||
@@ -27,38 +28,51 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void visit(GroupByElement groupByElement) {
|
public void visit(GroupByElement groupByElement) {
|
||||||
groupByElement.getGroupByExpressionList();
|
|
||||||
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
|
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
|
||||||
List<Expression> groupByExpressions = groupByExpressionList.getExpressions();
|
List<Expression> groupByExpressions = groupByExpressionList.getExpressions();
|
||||||
|
|
||||||
for (int i = 0; i < groupByExpressions.size(); i++) {
|
for (int i = 0; i < groupByExpressions.size(); i++) {
|
||||||
Expression expression = groupByExpressions.get(i);
|
Expression expression = groupByExpressions.get(i);
|
||||||
String columnName = expression.toString();
|
String columnName = getColumnName(expression);
|
||||||
if (expression instanceof Function && Objects.nonNull(
|
|
||||||
((Function) expression).getParameters().getExpressions().get(0))) {
|
String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap, exactReplace);
|
||||||
columnName = ((Function) expression).getParameters().getExpressions().get(0).toString();
|
|
||||||
}
|
|
||||||
String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap,
|
|
||||||
exactReplace);
|
|
||||||
if (StringUtils.isNotEmpty(replaceColumn)) {
|
if (StringUtils.isNotEmpty(replaceColumn)) {
|
||||||
if (expression instanceof Column) {
|
replaceExpression(groupByExpressions, i, expression, replaceColumn);
|
||||||
groupByExpressions.set(i, new Column(replaceColumn));
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getColumnName(Expression expression) {
|
||||||
if (expression instanceof Function) {
|
if (expression instanceof Function) {
|
||||||
|
Function function = (Function) expression;
|
||||||
|
if (Objects.nonNull(function.getParameters().getExpressions().get(0))) {
|
||||||
|
return function.getParameters().getExpressions().get(0).toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return expression.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void replaceExpression(List<Expression> groupByExpressions,
|
||||||
|
int index,
|
||||||
|
Expression expression,
|
||||||
|
String replaceColumn) {
|
||||||
|
if (expression instanceof Column) {
|
||||||
|
groupByExpressions.set(index, new Column(replaceColumn));
|
||||||
|
} else if (expression instanceof Function) {
|
||||||
try {
|
try {
|
||||||
Expression element = CCJSqlParserUtil.parseExpression(replaceColumn);
|
Expression newExpression = CCJSqlParserUtil.parseExpression(replaceColumn);
|
||||||
ExpressionList<Expression> expressionList = new ExpressionList<Expression>();
|
ExpressionList<Expression> newExpressionList = new ExpressionList<>();
|
||||||
expressionList.add(element);
|
newExpressionList.add(newExpression);
|
||||||
if (((Function) expression).getParameters().size() > 1) {
|
|
||||||
((Function) expression).getParameters().stream().skip(1).forEach(e -> {
|
Function function = (Function) expression;
|
||||||
expressionList.add((Function) e);
|
if (function.getParameters().size() > 1) {
|
||||||
});
|
function.getParameters().stream().skip(1).forEach(
|
||||||
|
e -> newExpressionList.add((Function) e)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
((Function) expression).setParameters(expressionList);
|
function.setParameters(newExpressionList);
|
||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
log.error("e", e);
|
log.error("Error parsing expression: {}", replaceColumn, e);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,7 +71,6 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//selectExpressionItem.getExpression().accept(this);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Expression replace(Expression expression, Map<String, String> fieldExprMap) {
|
public static Expression replace(Expression expression, Map<String, String> fieldExprMap) {
|
||||||
|
|||||||
@@ -1,10 +1,5 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||||
@@ -33,6 +28,12 @@ import net.sf.jsqlparser.statement.select.SelectItem;
|
|||||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser remove Helper
|
* Sql Parser remove Helper
|
||||||
*/
|
*/
|
||||||
@@ -228,7 +229,6 @@ public class SqlRemoveHelper {
|
|||||||
if (selectStatement == null) {
|
if (selectStatement == null) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
|
||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.function.UnaryOperator;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Alias;
|
import net.sf.jsqlparser.expression.Alias;
|
||||||
@@ -30,6 +21,7 @@ import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
|||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
import net.sf.jsqlparser.schema.Table;
|
import net.sf.jsqlparser.schema.Table;
|
||||||
|
import net.sf.jsqlparser.statement.select.FromItem;
|
||||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||||
import net.sf.jsqlparser.statement.select.Join;
|
import net.sf.jsqlparser.statement.select.Join;
|
||||||
import net.sf.jsqlparser.statement.select.OrderByElement;
|
import net.sf.jsqlparser.statement.select.OrderByElement;
|
||||||
@@ -40,11 +32,18 @@ import net.sf.jsqlparser.statement.select.Select;
|
|||||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||||
import net.sf.jsqlparser.statement.select.SetOperationList;
|
import net.sf.jsqlparser.statement.select.SetOperationList;
|
||||||
import net.sf.jsqlparser.statement.select.FromItem;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.function.UnaryOperator;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser replace Helper
|
* Sql Parser replace Helper
|
||||||
*/
|
*/
|
||||||
@@ -127,12 +126,10 @@ public class SqlReplaceHelper {
|
|||||||
if (!(selectStatement instanceof PlainSelect)) {
|
if (!(selectStatement instanceof PlainSelect)) {
|
||||||
return sql;
|
return sql;
|
||||||
}
|
}
|
||||||
//List<PlainSelect> plainSelectList = new ArrayList<>();
|
|
||||||
//plainSelectList.add((PlainSelect) selectStatement);
|
|
||||||
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement);
|
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement);
|
||||||
for (PlainSelect plainSelect : plainSelects) {
|
for (PlainSelect plainSelect : plainSelects) {
|
||||||
Expression where = plainSelect.getWhere();
|
Expression where = plainSelect.getWhere();
|
||||||
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap);
|
FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(exactReplace, filedNameToValueMap);
|
||||||
if (Objects.nonNull(where)) {
|
if (Objects.nonNull(where)) {
|
||||||
where.accept(visitor);
|
where.accept(visitor);
|
||||||
}
|
}
|
||||||
@@ -187,18 +184,14 @@ public class SqlReplaceHelper {
|
|||||||
public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
|
public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
|
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
|
||||||
//plainSelectList.add(selectStatement.getPlainSelect());
|
|
||||||
if (selectStatement instanceof PlainSelect) {
|
if (selectStatement instanceof PlainSelect) {
|
||||||
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
||||||
plainSelectList.add(plainSelect);
|
plainSelectList.add(plainSelect);
|
||||||
getFromSelect(plainSelect.getFromItem(), plainSelectList);
|
getFromSelect(plainSelect.getFromItem(), plainSelectList);
|
||||||
//plainSelectList.add((PlainSelect) selectStatement);
|
|
||||||
} else if (selectStatement instanceof SetOperationList) {
|
} else if (selectStatement instanceof SetOperationList) {
|
||||||
SetOperationList setOperationList = (SetOperationList) selectStatement;
|
SetOperationList setOperationList = (SetOperationList) selectStatement;
|
||||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||||
setOperationList.getSelects().forEach(subSelectBody -> {
|
setOperationList.getSelects().forEach(subSelectBody -> {
|
||||||
//PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
|
||||||
//plainSelectList.add(subPlainSelect);
|
|
||||||
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||||
plainSelectList.add(subPlainSelect);
|
plainSelectList.add(subPlainSelect);
|
||||||
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
|
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
|
||||||
@@ -546,7 +539,7 @@ public class SqlReplaceHelper {
|
|||||||
}
|
}
|
||||||
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
||||||
Expression having = plainSelect.getHaving();
|
Expression having = plainSelect.getHaving();
|
||||||
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap);
|
FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(false, filedNameToValueMap);
|
||||||
if (Objects.nonNull(having)) {
|
if (Objects.nonNull(having)) {
|
||||||
having.accept(visitor);
|
having.accept(visitor);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,6 @@
|
|||||||
package com.tencent.supersonic.common.jsqlparser;
|
package com.tencent.supersonic.common.jsqlparser;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Alias;
|
import net.sf.jsqlparser.expression.Alias;
|
||||||
@@ -50,6 +42,15 @@ import net.sf.jsqlparser.statement.select.WithItem;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sql Parser Select Helper
|
* Sql Parser Select Helper
|
||||||
*/
|
*/
|
||||||
@@ -97,6 +98,22 @@ public class SqlSelectHelper {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static List<String> gePureSelectFields(String sql) {
|
||||||
|
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
||||||
|
Set<String> result = new HashSet<>();
|
||||||
|
plainSelectList.stream().forEach(plainSelect -> {
|
||||||
|
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
|
||||||
|
for (SelectItem selectItem : selectItems) {
|
||||||
|
if (!(selectItem.getExpression() instanceof Column)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Column column = (Column) selectItem.getExpression();
|
||||||
|
result.add(column.getColumnName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return new ArrayList<>(result);
|
||||||
|
}
|
||||||
|
|
||||||
public static List<String> getSelectFields(String sql) {
|
public static List<String> getSelectFields(String sql) {
|
||||||
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
||||||
if (CollectionUtils.isEmpty(plainSelectList)) {
|
if (CollectionUtils.isEmpty(plainSelectList)) {
|
||||||
@@ -244,7 +261,7 @@ public class SqlSelectHelper {
|
|||||||
return plainSelects;
|
return plainSelects;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<String> getAllFields(String sql) {
|
public static List<String> getAllSelectFields(String sql) {
|
||||||
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
|
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
|
||||||
Set<String> results = new HashSet<>();
|
Set<String> results = new HashSet<>();
|
||||||
for (PlainSelect plainSelect : plainSelects) {
|
for (PlainSelect plainSelect : plainSelects) {
|
||||||
@@ -632,22 +649,6 @@ public class SqlSelectHelper {
|
|||||||
return withNameList;
|
return withNameList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Map<String, WithItem> getWith(String sql) {
|
|
||||||
Select selectStatement = getSelect(sql);
|
|
||||||
if (selectStatement == null) {
|
|
||||||
return new HashMap<>();
|
|
||||||
}
|
|
||||||
Map<String, WithItem> withMap = new HashMap<>();
|
|
||||||
List<WithItem> withItemList = selectStatement.getWithItemsList();
|
|
||||||
if (!CollectionUtils.isEmpty(withItemList)) {
|
|
||||||
for (int i = 0; i < withItemList.size(); i++) {
|
|
||||||
WithItem withItem = withItemList.get(i);
|
|
||||||
withMap.put(withItem.getAlias().getName(), withItem);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return withMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Table getTable(String sql) {
|
public static Table getTable(String sql) {
|
||||||
Select selectStatement = getSelect(sql);
|
Select selectStatement = getSelect(sql);
|
||||||
if (selectStatement == null) {
|
if (selectStatement == null) {
|
||||||
@@ -776,24 +777,25 @@ public class SqlSelectHelper {
|
|||||||
|
|
||||||
private static void getFieldsWithSubQuery(PlainSelect plainSelect, Map<String, Set<String>> fields) {
|
private static void getFieldsWithSubQuery(PlainSelect plainSelect, Map<String, Set<String>> fields) {
|
||||||
if (plainSelect.getFromItem() instanceof Table) {
|
if (plainSelect.getFromItem() instanceof Table) {
|
||||||
boolean isWith = false;
|
List<String> withAlias = new ArrayList<>();
|
||||||
if (!CollectionUtils.isEmpty(plainSelect.getWithItemsList())) {
|
if (!CollectionUtils.isEmpty(plainSelect.getWithItemsList())) {
|
||||||
for (WithItem withItem : plainSelect.getWithItemsList()) {
|
for (WithItem withItem : plainSelect.getWithItemsList()) {
|
||||||
if (Objects.nonNull(withItem.getSelect())) {
|
if (Objects.nonNull(withItem.getSelect())) {
|
||||||
getFieldsWithSubQuery(withItem.getSelect().getPlainSelect(), fields);
|
getFieldsWithSubQuery(withItem.getSelect().getPlainSelect(), fields);
|
||||||
isWith = true;
|
withAlias.add(withItem.getAlias().getName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!isWith) {
|
|
||||||
Table table = (Table) plainSelect.getFromItem();
|
Table table = (Table) plainSelect.getFromItem();
|
||||||
|
String tableName = table.getFullyQualifiedName();
|
||||||
|
if (!withAlias.contains(tableName)) {
|
||||||
if (!fields.containsKey(table.getFullyQualifiedName())) {
|
if (!fields.containsKey(table.getFullyQualifiedName())) {
|
||||||
fields.put(table.getFullyQualifiedName(), new HashSet<>());
|
fields.put(tableName, new HashSet<>());
|
||||||
}
|
}
|
||||||
List<String> sqlFields = getFieldsByPlainSelect(plainSelect).stream().map(f -> f.replaceAll("`", ""))
|
List<String> sqlFields = getFieldsByPlainSelect(plainSelect).stream().map(f -> f.replaceAll("`", ""))
|
||||||
.collect(
|
.collect(
|
||||||
Collectors.toList());
|
Collectors.toList());
|
||||||
fields.get(table.getFullyQualifiedName()).addAll(sqlFields);
|
fields.get(tableName).addAll(sqlFields);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
|
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
|
|||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
@@ -29,8 +28,8 @@ public class SqlValidHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//2. all fields
|
//2. all fields
|
||||||
List<String> thisAllFields = SqlSelectHelper.getAllFields(thisSql);
|
List<String> thisAllFields = SqlSelectHelper.getAllSelectFields(thisSql);
|
||||||
List<String> otherAllFields = SqlSelectHelper.getAllFields(otherSql);
|
List<String> otherAllFields = SqlSelectHelper.getAllSelectFields(otherSql);
|
||||||
|
|
||||||
if (!CollectionUtils.isEqualCollection(thisAllFields, otherAllFields)) {
|
if (!CollectionUtils.isEqualCollection(thisAllFields, otherAllFields)) {
|
||||||
return false;
|
return false;
|
||||||
@@ -69,7 +68,7 @@ public class SqlValidHelper {
|
|||||||
try {
|
try {
|
||||||
CCJSqlParserUtil.parse(sql);
|
CCJSqlParserUtil.parse(sql);
|
||||||
return true;
|
return true;
|
||||||
} catch (JSQLParserException e) {
|
} catch (Exception e) {
|
||||||
log.error("isValidSQL parse:{}", e);
|
log.error("isValidSQL parse:{}", e);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user