mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
176 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f766f84b9 | ||
|
|
5229fdc8b5 | ||
|
|
bef652892b | ||
|
|
d2306464a6 | ||
|
|
371e2f1e05 | ||
|
|
59c50176c5 | ||
|
|
be9a8bbc27 | ||
|
|
afa82bf98d | ||
|
|
7e013ca36a | ||
|
|
ca098b576c | ||
|
|
af270580bf | ||
|
|
4632549603 | ||
|
|
bd2eaef3f6 | ||
|
|
ba55ecb31e | ||
|
|
10a5e485cb | ||
|
|
2801b27ade | ||
|
|
07e0ba24bc | ||
|
|
115cf19078 | ||
|
|
898c7100ba | ||
|
|
7150f19def | ||
|
|
6aff51d394 | ||
|
|
8d29e89317 | ||
|
|
93fedd787f | ||
|
|
ebea58098c | ||
|
|
95be7f3ce1 | ||
|
|
d32d791238 | ||
|
|
4d65b8b93a | ||
|
|
4d02bb7068 | ||
|
|
c82c2d0a95 | ||
|
|
0c70df12ca | ||
|
|
1ff4a71a41 | ||
|
|
8b01dac8d4 | ||
|
|
2f2f493d17 | ||
|
|
b13b38c645 | ||
|
|
68952fdb55 | ||
|
|
ba9e6afa51 | ||
|
|
da0ac7b26c | ||
|
|
db698ecb75 | ||
|
|
cc3fa0078a | ||
|
|
e586d887ed | ||
|
|
ecc651e12d | ||
|
|
24c63c93bb | ||
|
|
e88e381302 | ||
|
|
f06cd0b296 | ||
|
|
1608317ab3 | ||
|
|
cdb67650c5 | ||
|
|
3ca51145e5 | ||
|
|
794a448619 | ||
|
|
9dbc8657e2 | ||
|
|
b8aeff9a6a | ||
|
|
82f86a8635 | ||
|
|
3d1ca6ac1d | ||
|
|
208686de46 | ||
|
|
c8fe6d2d04 | ||
|
|
45fb83356f | ||
|
|
5ad8ac69ab | ||
|
|
3621766a0d | ||
|
|
89b028b594 | ||
|
|
0a4272c25e | ||
|
|
e2e45a40ab | ||
|
|
97bf8049d7 | ||
|
|
a9232fa1c7 | ||
|
|
ac6b28ebb7 | ||
|
|
53a9f7c451 | ||
|
|
e26263d229 | ||
|
|
cabcbf16ff | ||
|
|
5a18ad5229 | ||
|
|
ac96f72b07 | ||
|
|
ddd44f343a | ||
|
|
55abc883ab | ||
|
|
27a70de1be | ||
|
|
4a5bb9e457 | ||
|
|
12a504585f | ||
|
|
b5fa54a754 | ||
|
|
52a584b9a4 | ||
|
|
b45f3d0663 | ||
|
|
e571ca1f55 | ||
|
|
9a1fac5d4c | ||
|
|
23af977972 | ||
|
|
2472ce2461 | ||
|
|
9e4513f7ca | ||
|
|
9a14728152 | ||
|
|
26f682cc45 | ||
|
|
ccd79e4830 | ||
|
|
e5504473a4 | ||
|
|
ebbb519c07 | ||
|
|
8f620480c6 | ||
|
|
85088abd7b | ||
|
|
f38a84bc8c | ||
|
|
8307c813d2 | ||
|
|
cd8f38c334 | ||
|
|
ae34c15c95 | ||
|
|
c8df102402 | ||
|
|
335902bd1f | ||
|
|
0f5b49f7c5 | ||
|
|
f0bdb14818 | ||
|
|
c39460ee02 | ||
|
|
865788b71b | ||
|
|
8f55fa0c1e | ||
|
|
f03e5a0d38 | ||
|
|
73d9cbdbc1 | ||
|
|
d64ed02df9 | ||
|
|
3797cc2ce8 | ||
|
|
7d64aa893c | ||
|
|
b5768b27aa | ||
|
|
b319f4682a | ||
|
|
ef44954e1e | ||
|
|
17f965eabd | ||
|
|
2425067091 | ||
|
|
2eac301076 | ||
|
|
f30c74c18f | ||
|
|
2cec2e61bc | ||
|
|
1b37925bf3 | ||
|
|
ff38a6d250 | ||
|
|
782b768950 | ||
|
|
1c0b8f8161 | ||
|
|
35892f2344 | ||
|
|
13ae312e51 | ||
|
|
4d1360b924 | ||
|
|
529251097b | ||
|
|
d5c78d87e7 | ||
|
|
4eb6193699 | ||
|
|
407c8d4702 | ||
|
|
baff30550e | ||
|
|
e365a36749 | ||
|
|
5bf4a4160d | ||
|
|
37da1ac2ae | ||
|
|
41ad1ada6c | ||
|
|
f9d6ea11c5 | ||
|
|
d6c5702b5a | ||
|
|
e0647dd990 | ||
|
|
ee258d4c8f | ||
|
|
d7c1b2cbaa | ||
|
|
e041fcb37e | ||
|
|
9bb95ca4be | ||
|
|
76fb3cb4a2 | ||
|
|
78a91ad8c2 | ||
|
|
03f5678732 | ||
|
|
62e70f5cb7 | ||
|
|
8ef5ce8b76 | ||
|
|
d3f3fc5de3 | ||
|
|
c3b3b7e769 | ||
|
|
ea4aa3eacf | ||
|
|
f0b4eb46cf | ||
|
|
7a376bd9a3 | ||
|
|
c9c049a20f | ||
|
|
efd617b2e5 | ||
|
|
9911e6772c | ||
|
|
08ae27ab43 | ||
|
|
64786cb0ef | ||
|
|
4d7bfe07aa | ||
|
|
3f460429e6 | ||
|
|
d849acf971 | ||
|
|
e0e77a3b64 | ||
|
|
609146d12c | ||
|
|
6db6aaf98d | ||
|
|
d39db734c4 | ||
|
|
a1ab7ac1c1 | ||
|
|
16c3ff0c30 | ||
|
|
7c86e2b3db | ||
|
|
2ddf0ad41a | ||
|
|
66e5ee06e6 | ||
|
|
72465cd88c | ||
|
|
097f2f4fe7 | ||
|
|
71954e42a8 | ||
|
|
14b9086d83 | ||
|
|
93ea7a618c | ||
|
|
bb4cd880b0 | ||
|
|
fa5abc58a5 | ||
|
|
d200086779 | ||
|
|
5d5ca438a6 | ||
|
|
c5478ad8a2 | ||
|
|
20aa0bc0a9 | ||
|
|
9cd352f146 | ||
|
|
78e023e955 | ||
|
|
ccf41fa6db |
30
.github/workflows/centos-ci.yml
vendored
30
.github/workflows/centos-ci.yml
vendored
@@ -1,5 +1,4 @@
|
||||
name: supersonic RHEL/CentOS CI
|
||||
|
||||
name: supersonic CentOS CI
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
@@ -14,31 +13,52 @@ jobs:
|
||||
container:
|
||||
image: quay.io/centos/centos:stream8 # 使用 CentOS Stream 8 容器
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
java-version: [8, 11, 21] # 定义要测试的JDK版本
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Reset DNF repositories
|
||||
run: |
|
||||
cd /etc/yum.repos.d/
|
||||
sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-*
|
||||
sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
|
||||
|
||||
- name: Update DNF package index
|
||||
run: dnf makecache
|
||||
|
||||
- name: Install Java and Maven with retry
|
||||
run: |
|
||||
for i in {1..5}; do
|
||||
dnf install -y java-1.8.0-openjdk-devel maven && break || sleep 15
|
||||
done
|
||||
if [ ${{ matrix.java-version }} -eq 8 ]; then
|
||||
for i in {1..5}; do
|
||||
dnf install -y java-1.8.0-openjdk-devel maven && break || sleep 15
|
||||
done
|
||||
elif [ ${{ matrix.java-version }} -eq 11 ]; then
|
||||
for i in {1..5}; do
|
||||
dnf install -y java-11-openjdk-devel maven && break || sleep 15
|
||||
done
|
||||
elif [ ${{ matrix.java-version }} -eq 21 ]; then
|
||||
for i in {1..5}; do
|
||||
dnf install -y java-21-openjdk-devel maven && break || sleep 15
|
||||
done
|
||||
fi
|
||||
|
||||
- name: Verify Java and Maven installation
|
||||
run: |
|
||||
java -version
|
||||
mvn -version
|
||||
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.m2
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2
|
||||
|
||||
- name: Build with Maven
|
||||
run: mvn -B package --file pom.xml
|
||||
|
||||
- name: Test with Maven
|
||||
run: mvn test
|
||||
8
.github/workflows/mac-ci.yml
vendored
8
.github/workflows/mac-ci.yml
vendored
@@ -12,13 +12,17 @@ jobs:
|
||||
build:
|
||||
runs-on: macos-latest # Specify a macOS runner
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
java-version: [8, 11, 21] # Define the JDK versions to test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up JDK 8
|
||||
- name: Set up JDK ${{ matrix.java-version }}
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: '8'
|
||||
java-version: ${{ matrix.java-version }}
|
||||
distribution: 'adopt'
|
||||
|
||||
- name: Cache Maven packages
|
||||
|
||||
42
.github/workflows/ubuntu-ci.yml
vendored
42
.github/workflows/ubuntu-ci.yml
vendored
@@ -7,25 +7,33 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
java-version: [8, 11, 21] # 定义要测试的JDK版本
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up JDK 8
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: '8'
|
||||
distribution: 'adopt'
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.m2
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2
|
||||
- name: Build with Maven
|
||||
run: mvn -B package --file pom.xml
|
||||
- name: Test with Maven
|
||||
run: mvn test
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up JDK ${{ matrix.java-version }}
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: ${{ matrix.java-version }}
|
||||
distribution: 'adopt'
|
||||
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.m2
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2
|
||||
|
||||
- name: Build with Maven
|
||||
run: mvn -B package --file pom.xml
|
||||
|
||||
- name: Test with Maven
|
||||
run: mvn test
|
||||
10
.github/workflows/windows-ci.yml
vendored
10
.github/workflows/windows-ci.yml
vendored
@@ -12,14 +12,18 @@ jobs:
|
||||
build:
|
||||
runs-on: windows-latest # Specify a Windows runner
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
java-version: [8, 11, 21] # Add JDK 21 to the matrix
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up JDK 8
|
||||
- name: Set up JDK ${{ matrix.java-version }}
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: '8'
|
||||
distribution: 'adopt'
|
||||
java-version: ${{ matrix.java-version }}
|
||||
distribution: 'adopt' # You might need to change this if 'adopt' doesn't support JDK 21
|
||||
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v2
|
||||
|
||||
12
README.md
12
README.md
@@ -60,6 +60,12 @@ The high-level architecture and main process flow is as follows:
|
||||
### Online playground
|
||||
Visit http://117.72.46.148:9080 to register and experience as a new user. Please do not modify system configurations. We will restart to reset configurations regularly every weekend.
|
||||
|
||||
### Docker Deployment
|
||||
- Install Docker and docker-compose.
|
||||
- Download the docker-compose.yml file; Execute: wget https://raw.githubusercontent.com/tencentmusic/supersonic/master/docker/docker-compose.yml.
|
||||
- Execute "docker-compose up -d".
|
||||
- Open a browser and visit http://localhost:9080 to start exploring.
|
||||
|
||||
### Local build
|
||||
SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:
|
||||
|
||||
@@ -75,8 +81,4 @@ Please refer to project [Docs](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%
|
||||
|
||||
Please follow SuperSonic wechat official account:
|
||||
|
||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
||||
|
||||
Welcome to join the WeChat community:
|
||||
|
||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat.png" height="50%" width="50%" />
|
||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
||||
|
||||
12
README_CN.md
12
README_CN.md
@@ -59,6 +59,12 @@ SuperSonic的整体架构和主流程如下图所示:
|
||||
### 线上环境体验
|
||||
访问http://117.72.46.148:9080 注册新用户体验. 请勿修改系统配置。我们每周末定期重启重置配置。
|
||||
|
||||
### Docker部署
|
||||
- 安装好Docker以及docker-compose
|
||||
- 下载docker-compose.yml;执行命令:wget https://raw.githubusercontent.com/tencentmusic/supersonic/master/docker/docker-compose.yml
|
||||
- 执行:"docker-compose up -d"
|
||||
- 在浏览器访问http://localhost:9080 开启探索
|
||||
|
||||
### 本地构建
|
||||
|
||||
SuperSonic自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
||||
@@ -75,8 +81,4 @@ SuperSonic自带样例的语义模型和问答对话,只需以下三步即可
|
||||
|
||||
欢迎关注微信公众号:
|
||||
|
||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
||||
|
||||
欢迎加入微信社群:
|
||||
|
||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat.png" height="50%" width="50%" />
|
||||
<img src="https://github.com/supersonicbi/supersonic-website/blob/main/static/img/supersonic_wechat_oa.png" height="50%" width="50%" />
|
||||
|
||||
10
README_JP.md
10
README_JP.md
@@ -56,6 +56,12 @@ ChatGPTのような大規模言語モデル(LLM)の出現は、情報検索
|
||||
### オンラインプレイグラウンド
|
||||
http://117.72.46.148:9080 にアクセスして、新規ユーザーとして登録して体験してください。システム設定を変更しないでください。毎週末に定期的に再起動して設定をリセットします。
|
||||
|
||||
### Dockerのデプロイメント
|
||||
- Dockerおよびdocker-composeをインストールします。
|
||||
- docker-compose.ymlファイルをダウンロードします。コマンドを実行します:wget https://raw.githubusercontent.com/tencentmusic/supersonic/master/docker/docker-compose.yml。
|
||||
- docker-compose up -dを実行します。
|
||||
- ブラウザを開いてhttp://localhost:9080にアクセスし、探索を開始します。
|
||||
|
||||
### ローカルビルド
|
||||
SuperSonicには、サンプルのセマンティックモデルとチャット会話が付属しており、以下の手順で簡単に体験できます:
|
||||
|
||||
@@ -72,7 +78,3 @@ SuperSonicには、サンプルのセマンティックモデルとチャット
|
||||
SuperSonicの公式WeChatアカウントをフォローしてください:
|
||||
|
||||

|
||||
|
||||
WeChatコミュニティに参加することを歓迎します:
|
||||
|
||||

|
||||
|
||||
@@ -54,7 +54,7 @@ if "%command%"=="restart" (
|
||||
set "webDir=%baseDir%\webapp"
|
||||
set "logDir=%baseDir%\logs"
|
||||
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Xms1024m -Xmx1024m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||
if not exist %logDir% mkdir %logDir%
|
||||
start /B java %java-command% >nul 2>&1
|
||||
timeout /t 10 >nul
|
||||
|
||||
@@ -59,7 +59,7 @@ function runJavaService {
|
||||
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
||||
fi
|
||||
export PATH=$JAVA_HOME/bin:$PATH
|
||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m $main_class"
|
||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dapp_name=${local_app_name} -Xms1024m -Xmx1024m $main_class"
|
||||
|
||||
mkdir -p $javaRunDir/logs
|
||||
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.auth.api.authorization.pojo;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AuthResGrp {
|
||||
|
||||
private List<AuthRes> group = new ArrayList<>();
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.auth.api.authorization.request;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -15,8 +14,6 @@ public class QueryAuthResReq {
|
||||
|
||||
private List<String> departmentIds = new ArrayList<>();
|
||||
|
||||
private List<AuthRes> resources;
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.tencent.supersonic.auth.api.authorization.response;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class AuthorizedResourceResp {
|
||||
|
||||
private List<AuthResGrp> resources = new ArrayList<>();
|
||||
private List<AuthRes> authResList = new ArrayList<>();
|
||||
|
||||
private List<DimensionFilter> filters = new ArrayList<>();
|
||||
}
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
package com.tencent.supersonic.auth.authorization.service;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
||||
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
||||
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -79,53 +76,35 @@ public class AuthServiceImpl implements AuthService {
|
||||
|
||||
@Override
|
||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||
if (CollectionUtils.isEmpty(req.getModelIds())) {
|
||||
return new AuthorizedResourceResp();
|
||||
}
|
||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
|
||||
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
||||
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
||||
Map<Long, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
||||
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
||||
|
||||
for (Long modelId : reqAuthRes.keySet()) {
|
||||
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
||||
AuthResGrp rg = new AuthResGrp();
|
||||
for (Long modelId : req.getModelIds()) {
|
||||
if (authGroupsByModelId.containsKey(modelId)) {
|
||||
List<AuthGroup> authGroups = authGroupsByModelId.get(modelId);
|
||||
for (AuthRes reqRes : reqResourcesList) {
|
||||
for (AuthGroup authRuleGroup : authGroups) {
|
||||
List<AuthRule> authRules = authRuleGroup.getAuthRules();
|
||||
List<String> allAuthItems = new ArrayList<>();
|
||||
authRules.forEach(authRule -> allAuthItems.addAll(authRule.resourceNames()));
|
||||
|
||||
if (allAuthItems.contains(reqRes.getName())) {
|
||||
rg.getGroup().add(reqRes);
|
||||
for (AuthGroup authRuleGroup : authGroups) {
|
||||
List<AuthRule> authRules = authRuleGroup.getAuthRules();
|
||||
for (AuthRule authRule : authRules) {
|
||||
for (String resBizName : authRule.resourceNames()) {
|
||||
resource.getAuthResList().add(new AuthRes(modelId, resBizName));
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(rg.getGroup())) {
|
||||
resource.getResources().add(rg);
|
||||
}
|
||||
}
|
||||
|
||||
if (!CollectionUtils.isEmpty(req.getModelIds())) {
|
||||
List<AuthGroup> authGroups = Lists.newArrayList();
|
||||
for (Long modelId : authGroupsByModelId.keySet()) {
|
||||
authGroups.addAll(authGroupsByModelId.getOrDefault(modelId, Lists.newArrayList()));
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(authGroups)) {
|
||||
for (AuthGroup group : authGroups) {
|
||||
if (group.getDimensionFilters() != null
|
||||
&& group.getDimensionFilters().stream().anyMatch(expr ->
|
||||
!StringUtils.isEmpty(expr))) {
|
||||
DimensionFilter df = new DimensionFilter();
|
||||
df.setDescription(group.getDimensionFilterDescription());
|
||||
df.setExpressions(group.getDimensionFilters());
|
||||
resource.getFilters().add(df);
|
||||
}
|
||||
}
|
||||
Set<Map.Entry<Long, List<AuthGroup>>> entries = authGroupsByModelId.entrySet();
|
||||
for (Map.Entry<Long, List<AuthGroup>> entry : entries) {
|
||||
List<AuthGroup> authGroups = entry.getValue();
|
||||
for (AuthGroup authGroup : authGroups) {
|
||||
DimensionFilter df = new DimensionFilter();
|
||||
df.setDescription(authGroup.getDimensionFilterDescription());
|
||||
df.setExpressions(authGroup.getDimensionFilters());
|
||||
resource.getFilters().add(df);
|
||||
}
|
||||
}
|
||||
return resource;
|
||||
@@ -134,11 +113,11 @@ public class AuthServiceImpl implements AuthService {
|
||||
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
|
||||
List<AuthGroup> groups = load().stream()
|
||||
.filter(group -> {
|
||||
if (CollectionUtils.isEmpty(modelIds) || !modelIds.contains(group.getModelId())) {
|
||||
if (!modelIds.contains(group.getModelId())) {
|
||||
return false;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
||||
.contains(userName)) {
|
||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers())
|
||||
&& group.getAuthorizedUsers().contains(userName)) {
|
||||
return true;
|
||||
}
|
||||
for (String departmentId : departmentIds) {
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import lombok.Data;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.response;
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -4,8 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -23,6 +24,7 @@ public class Agent extends RecordInfo {
|
||||
|
||||
private Integer id;
|
||||
private Integer enableSearch;
|
||||
private Integer enableMemoryReview;
|
||||
private String name;
|
||||
private String description;
|
||||
|
||||
@@ -32,7 +34,8 @@ public class Agent extends RecordInfo {
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
private LLMConfig llmConfig;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
@@ -58,6 +61,10 @@ public class Agent extends RecordInfo {
|
||||
return enableSearch != null && enableSearch == 1;
|
||||
}
|
||||
|
||||
public boolean enableMemoryReview() {
|
||||
return enableMemoryReview != null && enableMemoryReview == 1;
|
||||
}
|
||||
|
||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatExecutor {
|
||||
|
||||
QueryResult execute(ChatExecuteContext chatExecuteContext);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatQueryExecutor {
|
||||
|
||||
QueryResult execute(ExecuteContext executeContext);
|
||||
|
||||
}
|
||||
@@ -1,29 +1,30 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
public class PlainTextExecutor implements ChatExecutor {
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a nice person to talk to.\n"
|
||||
@@ -34,34 +35,34 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
+ "#Your response: ";
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
|
||||
chatExecuteContext.getQueryText());
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
result.setQueryState(QueryState.SUCCESS);
|
||||
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||
result.setQueryMode(executeContext.getParseInfo().getQueryMode());
|
||||
result.setTextResult(response.content().text());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
|
||||
private String getHistoryInputs(ExecuteContext executeContext) {
|
||||
StringBuilder historyInput = new StringBuilder();
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
||||
@@ -70,8 +71,8 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
|
||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
|
||||
parseResps.stream().forEach(p -> {
|
||||
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
||||
queryResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
});
|
||||
@@ -80,12 +81,15 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
return historyInput.toString();
|
||||
}
|
||||
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
||||
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||
.stream()
|
||||
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
|
||||
|
||||
@@ -2,15 +2,15 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
public class PluginExecutor implements ChatExecutor {
|
||||
public class PluginExecutor implements ChatQueryExecutor {
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -2,28 +2,34 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import lombok.SneakyThrows;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import java.util.Date;
|
||||
|
||||
public class SqlExecutor implements ChatExecutor {
|
||||
import java.util.Date;
|
||||
import java.util.Objects;
|
||||
|
||||
public class SqlExecutor implements ChatQueryExecutor {
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
QueryResult queryResult = doExecute(executeContext);
|
||||
|
||||
if (queryResult != null) {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
queryResult.getQueryResults());
|
||||
@@ -31,14 +37,20 @@ public class SqlExecutor implements ChatExecutor {
|
||||
|
||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
Text2SQLExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||
executeContext.getParseInfo().getProperties()
|
||||
.get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class);
|
||||
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(chatExecuteContext.getAgentId())
|
||||
.agentId(executeContext.getAgent().getId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(chatExecuteContext.getQueryText())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL())
|
||||
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||
.createdBy(chatExecuteContext.getUser().getName())
|
||||
.question(exemplar.getQuestion())
|
||||
.sideInfo(exemplar.getSideInfo())
|
||||
.dbSchema(exemplar.getDbSchema())
|
||||
.s2sql(exemplar.getSql())
|
||||
.createdBy(executeContext.getUser().getName())
|
||||
.updatedBy(executeContext.getUser().getName())
|
||||
.createdAt(new Date())
|
||||
.build());
|
||||
}
|
||||
@@ -47,48 +59,43 @@ public class SqlExecutor implements ChatExecutor {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private ExecuteQueryReq buildExecuteReq(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
return ExecuteQueryReq.builder()
|
||||
.queryId(chatExecuteContext.getQueryId())
|
||||
.chatId(chatExecuteContext.getChatId())
|
||||
.queryText(chatExecuteContext.getQueryText())
|
||||
.parseInfo(parseInfo)
|
||||
.saveAnswer(chatExecuteContext.isSaveAnswer())
|
||||
.user(chatExecuteContext.getUser())
|
||||
@SneakyThrows
|
||||
private QueryResult doExecute(ExecuteContext executeContext) {
|
||||
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
if (Objects.isNull(parseInfo.getSqlInfo())
|
||||
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
QuerySqlReq sqlReq = QuerySqlReq.builder()
|
||||
.sql(parseInfo.getSqlInfo().getCorrectedS2SQL())
|
||||
.build();
|
||||
}
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
|
||||
public String buildSchemaStr(SemanticParseInfo parseInfo) {
|
||||
String tableStr = parseInfo.getDataSet().getName();
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
long startTime = System.currentTimeMillis();
|
||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setChatContext(parseInfo);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQueryResults(queryResp.getResultList());
|
||||
queryResult.setQueryColumns(queryResp.getColumns());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
parseInfo.getMetrics().stream().forEach(
|
||||
metric -> {
|
||||
metricStr.append(metric.getName());
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||
}
|
||||
metricStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
parseInfo.getDimensions().stream().forEach(
|
||||
dimension -> {
|
||||
dimensionStr.append(dimension.getName());
|
||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||
}
|
||||
dimensionStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
||||
return String.format(template, tableStr, metricStr, dimensionStr);
|
||||
chatCtx.setParseInfo(parseInfo);
|
||||
chatContextService.updateContext(chatCtx);
|
||||
} else {
|
||||
queryResult.setQueryState(QueryState.INVALID);
|
||||
}
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -32,9 +32,11 @@ public class MemoryReviewTask {
|
||||
+ "please take a review and give your opinion.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||
+ "2.DO NOT check the usage of `数据日期` field and `datediff()` function.\n"
|
||||
+ "2.ALWAYS recognize `数据日期` as the date field."
|
||||
+ "3.IGNORE `数据日期` if not expressed in the `Question`."
|
||||
+ "#Question: %s\n"
|
||||
+ "#Schema: %s\n"
|
||||
+ "#SideInfo: %s\n"
|
||||
+ "#SQL: %s\n"
|
||||
+ "#Response: ";
|
||||
|
||||
@@ -51,28 +53,33 @@ public class MemoryReviewTask {
|
||||
memoryService.getMemoriesForLlmReview().stream()
|
||||
.forEach(m -> {
|
||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||
if (Objects.nonNull(chatAgent)) {
|
||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getS2sql());
|
||||
if (Objects.nonNull(chatAgent) && chatAgent.enableMemoryReview()) {
|
||||
String promptStr = String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(),
|
||||
m.getSideInfo(), m.getS2sql());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(
|
||||
chatAgent.getLlmConfig());
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
chatAgent.getModelConfig());
|
||||
if (Objects.nonNull(chatLanguageModel)) {
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||
|
||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||
if (matcher.find()) {
|
||||
m.setLlmReviewRet(MemoryReviewResult.valueOf(matcher.group(1)));
|
||||
m.setLlmReviewCmt(matcher.group(2));
|
||||
// directly enable memory if the LLM determines it positive
|
||||
if (MemoryReviewResult.POSITIVE.equals(m.getLlmReviewRet())) {
|
||||
memoryService.enableMemory(m);
|
||||
}
|
||||
memoryService.updateMemory(m);
|
||||
}
|
||||
} else {
|
||||
log.debug("ChatLanguageModel not found for agent:{}", chatAgent.getId());
|
||||
}
|
||||
} else {
|
||||
log.debug("Agent not found for memory:{}", m.getAgentId());
|
||||
log.debug("Agent id {} not found or memory review disabled", m.getAgentId());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatParser {
|
||||
|
||||
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatQueryParser {
|
||||
|
||||
void parse(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -9,18 +9,18 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class NL2PluginParser implements ChatParser {
|
||||
public class NL2PluginParser implements ChatQueryParser {
|
||||
|
||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.getAgent().containsPluginTool()) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.getAgent().containsPluginTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||
pluginRecognizer.recognize(parseContext, parseResp);
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
});
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -17,65 +16,93 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static final String REWRITE_INSTRUCTION = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
||||
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements."
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||
+ "please try understanding the semantics and rewrite a question.\n"
|
||||
+ "please try understanding the semantics and rewrite a question."
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
||||
+ "2.ONLY respond with the rewritten question.\n"
|
||||
+ "#Current Question: %s\n"
|
||||
+ "#Current Mapped Schema: %s\n"
|
||||
+ "#History Question: %s\n"
|
||||
+ "#History Mapped Schema: %s\n"
|
||||
+ "#History SQL: %s\n"
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
|
||||
+ "2.ONLY respond with the rewritten question."
|
||||
+ "#Current Question: {{current_question}}"
|
||||
+ "#Current Mapped Schema: {{current_schema}}"
|
||||
+ "#History Question: {{history_question}}"
|
||||
+ "#History Mapped Schema: {{history_schema}}"
|
||||
+ "#History SQL: {{history_sql}}"
|
||||
+ "#Rewritten Question: ";
|
||||
|
||||
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
||||
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||
+ "please respond shortly to teach user how to ask the right question, "
|
||||
+ "by using `Examples` as references."
|
||||
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
|
||||
+ "#Input: {{user_question}}\n"
|
||||
+ "#Output: {{system_message}}\n"
|
||||
+ "#Examples: {{examples}}\n"
|
||||
+ "#Response: ";
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
processMultiTurn(chatParseContext);
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
||||
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addExemplars(chatParseContext.getAgent().getId(), queryReq);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
parseContext.getAgent().getModelConfig());
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
processMultiTurn(chatLanguageModel, parseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
|
||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
|
||||
parseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(),
|
||||
queryNLReq.getDynamicExemplars(),
|
||||
parseContext.getAgent().getExamples()));
|
||||
}
|
||||
parseResp.setState(text2SqlParseResp.getState());
|
||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||
formatParseResult(parseResp);
|
||||
}
|
||||
@@ -135,9 +162,9 @@ public class NL2SQLParser implements ChatParser {
|
||||
parseInfo.setTextInfo(textBuilder.toString());
|
||||
}
|
||||
|
||||
private void processMultiTurn(ChatParseContext chatParseContext) {
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||
@@ -147,45 +174,63 @@ public class NL2SQLParser implements ChatParser {
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp currentMapResult = chatQueryService.performMapping(queryReq);
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
|
||||
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
|
||||
if (historyParseResults.size() == 0) {
|
||||
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
||||
if (historyQueries.size() == 0) {
|
||||
return;
|
||||
}
|
||||
ParseResp lastParseResult = historyParseResults.get(0);
|
||||
Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId();
|
||||
QueryResp lastQuery = historyQueries.get(0);
|
||||
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
|
||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||
.curtQuestion(currentMapResult.getQueryText())
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.histSQL(histSQL)
|
||||
.llmConfig(queryReq.getLlmConfig())
|
||||
.build());
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("current_question", currentMapResult.getQueryText());
|
||||
variables.put("current_schema", curtMapStr);
|
||||
variables.put("history_question", lastQuery.getQueryText());
|
||||
variables.put("history_schema", histMapStr);
|
||||
variables.put("history_sql", histSQL);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String rewrittenQuery = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
|
||||
|
||||
parseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
private String rewriteQuery(RewriteContext context) {
|
||||
String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(), context.getCurtSchema(),
|
||||
context.getHistQuestion(), context.getHistSchema(), context.getHistSQL());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
|
||||
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||
String errMsg, List<Text2SQLExemplar> similarExemplars,
|
||||
List<String> agentExamples) {
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("user_question", userQuestion);
|
||||
variables.put("system_message", errMsg);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
|
||||
StringBuilder exampleStr = new StringBuilder();
|
||||
similarExemplars.forEach(e ->
|
||||
exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
|
||||
agentExamples.forEach(e ->
|
||||
exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||
variables.put("examples", exampleStr);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", result);
|
||||
return response.content().text();
|
||||
String rewrittenMsg = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenMsg);
|
||||
|
||||
return rewrittenMsg;
|
||||
}
|
||||
|
||||
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
||||
@@ -213,36 +258,27 @@ public class NL2SQLParser implements ChatParser {
|
||||
return prompt.toString();
|
||||
}
|
||||
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
||||
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
|
||||
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
|
||||
List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId)
|
||||
.stream()
|
||||
.filter(q -> Objects.nonNull(q.getQueryResult())
|
||||
&& q.getQueryResult().getQueryState() == QueryState.SUCCESS)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
List<QueryResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
return contextualList;
|
||||
}
|
||||
|
||||
private void addExemplars(Integer agentId, QueryReq queryReq) {
|
||||
private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) {
|
||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
queryReq.getQueryText(), 5);
|
||||
queryReq.getExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
public static class RewriteContext {
|
||||
|
||||
private String curtQuestion;
|
||||
private String histQuestion;
|
||||
private String curtSchema;
|
||||
private String histSchema;
|
||||
private String histSQL;
|
||||
private LLMConfig llmConfig;
|
||||
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
queryNLReq.getQueryText(), 5);
|
||||
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.ParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("ChatParserConfig")
|
||||
@Service("ChatQueryParserConfig")
|
||||
@Slf4j
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
@@ -17,11 +14,4 @@ public class ParserConfig extends ParameterConfig {
|
||||
"是否开启多轮对话", "开启多轮对话将消耗更多token",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
PARSER_MULTI_TURN_ENABLE
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
|
||||
public class PlainTextParser implements ChatParser {
|
||||
public class PlainTextParser implements ChatQueryParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (chatParseContext.getAgent().containsAnyTool()) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (parseContext.getAgent().containsAnyTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||
parseResp.getSelectedParses().add(parseInfo);
|
||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
@@ -12,15 +11,18 @@ import java.util.Date;
|
||||
@TableName("s2_agent")
|
||||
public class AgentDO {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String name;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String description;
|
||||
|
||||
@@ -30,37 +32,45 @@ public class AgentDO {
|
||||
private Integer status;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String examples;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String config;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String createdBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date createdAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String updatedBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date updatedAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Integer enableSearch;
|
||||
|
||||
private String llmConfig;
|
||||
|
||||
private Integer enableMemoryReview;
|
||||
private String modelConfig;
|
||||
private String multiTurnConfig;
|
||||
|
||||
private String visualConfig;
|
||||
|
||||
private String promptConfig;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.dataobject;
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -20,11 +20,14 @@ public class ChatMemoryDO {
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
@TableField("agent_id")
|
||||
private Integer agentId;
|
||||
|
||||
@TableField("question")
|
||||
private String question;
|
||||
|
||||
@TableField("agent_id")
|
||||
private Integer agentId;
|
||||
@TableField("side_info")
|
||||
private String sideInfo;
|
||||
|
||||
@TableField("db_schema")
|
||||
private String dbSchema;
|
||||
|
||||
@@ -1,142 +1,25 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
|
||||
@Data
|
||||
public class ChatParseDO {
|
||||
|
||||
/**
|
||||
* questionId
|
||||
*/
|
||||
private Long questionId;
|
||||
|
||||
/**
|
||||
* chatId
|
||||
*/
|
||||
private Long chatId;
|
||||
private Integer chatId;
|
||||
|
||||
/**
|
||||
* parseId
|
||||
*/
|
||||
private Integer parseId;
|
||||
|
||||
/**
|
||||
* createTime
|
||||
*/
|
||||
private Date createTime;
|
||||
|
||||
/**
|
||||
* queryText
|
||||
*/
|
||||
private String queryText;
|
||||
|
||||
/**
|
||||
* userName
|
||||
*/
|
||||
private String userName;
|
||||
|
||||
|
||||
/**
|
||||
* parseInfo
|
||||
*/
|
||||
private String parseInfo;
|
||||
|
||||
/**
|
||||
* isCandidate
|
||||
*/
|
||||
private Integer isCandidate;
|
||||
|
||||
/**
|
||||
* return question_id
|
||||
*/
|
||||
public Long getQuestionId() {
|
||||
return questionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* questionId
|
||||
*/
|
||||
public void setQuestionId(Long questionId) {
|
||||
this.questionId = questionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* return create_time
|
||||
*/
|
||||
public Date getCreateTime() {
|
||||
return createTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* createTime
|
||||
*/
|
||||
public void setCreateTime(Date createTime) {
|
||||
this.createTime = createTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* return user_name
|
||||
*/
|
||||
public String getUserName() {
|
||||
return userName;
|
||||
}
|
||||
|
||||
/**
|
||||
* userName
|
||||
*/
|
||||
public void setUserName(String userName) {
|
||||
this.userName = userName == null ? null : userName.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* return chat_id
|
||||
*/
|
||||
public Long getChatId() {
|
||||
return chatId;
|
||||
}
|
||||
|
||||
/**
|
||||
* chatId
|
||||
*/
|
||||
public void setChatId(Long chatId) {
|
||||
this.chatId = chatId;
|
||||
}
|
||||
|
||||
/**
|
||||
* return query_text
|
||||
*/
|
||||
public String getQueryText() {
|
||||
return queryText;
|
||||
}
|
||||
|
||||
/**
|
||||
* queryText
|
||||
*/
|
||||
public void setQueryText(String queryText) {
|
||||
this.queryText = queryText == null ? null : queryText.trim();
|
||||
}
|
||||
|
||||
public Integer getIsCandidate() {
|
||||
return isCandidate;
|
||||
}
|
||||
|
||||
public Integer getParseId() {
|
||||
return parseId;
|
||||
}
|
||||
|
||||
public String getParseInfo() {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public void setParseId(Integer parseId) {
|
||||
this.parseId = parseId;
|
||||
}
|
||||
|
||||
public void setIsCandidate(Integer isCandidate) {
|
||||
this.isCandidate = isCandidate;
|
||||
}
|
||||
|
||||
public void setParseInfo(String parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.dataobject;
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.mapper;
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.mapper;
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.repository;
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextRepository {
|
||||
|
||||
@@ -18,6 +18,8 @@ public interface ChatQueryRepository {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ChatQueryDO getChatQueryDO(Long queryId);
|
||||
|
||||
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
||||
@@ -35,6 +37,4 @@ public interface ChatQueryRepository {
|
||||
|
||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||
|
||||
List<ParseResp> getContextualParseInfo(Integer chatId);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.repository.impl;
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
@@ -20,7 +20,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -61,7 +61,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
if (!CollectionUtils.isEmpty(pageQueryInfoReq.getIds())) {
|
||||
queryWrapper.lambda().in(ChatQueryDO::getQuestionId, pageQueryInfoReq.getIds());
|
||||
}
|
||||
|
||||
queryWrapper.lambda().isNotNull(ChatQueryDO::getQueryResult);
|
||||
queryWrapper.lambda().ne(ChatQueryDO::getQueryResult, "");
|
||||
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
||||
|
||||
PageInfo<ChatQueryDO> pageInfo = PageHelper.startPage(pageQueryInfoReq.getCurrent(),
|
||||
@@ -70,8 +71,9 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
|
||||
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
|
||||
chatQueryVOPageInfo.setList(
|
||||
pageInfo.getList().stream().filter(o -> !StringUtils.isEmpty(o.getQueryResult())).map(this::convertTo)
|
||||
pageInfo.getList().stream()
|
||||
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
|
||||
.map(this::convertTo)
|
||||
.collect(Collectors.toList()));
|
||||
return chatQueryVOPageInfo;
|
||||
}
|
||||
@@ -90,6 +92,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return chatQueryDOMapper.selectById(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
QueryWrapper<ChatQueryDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId);
|
||||
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
|
||||
return chatQueryDOMapper.selectList(queryWrapper).stream()
|
||||
.map(q -> convertTo(q))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
||||
@@ -145,7 +157,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
|
||||
chatParseDO.setChatId(chatParseReq.getChatId());
|
||||
chatParseDO.setQuestionId(queryId);
|
||||
chatParseDO.setQueryText(chatParseReq.getQueryText());
|
||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||
@@ -179,17 +191,4 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return chatParseMapper.getParseInfoList(questionIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ParseResp> getContextualParseInfo(Integer chatId) {
|
||||
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
|
||||
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
|
||||
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
|
||||
List<SemanticParseInfo> selectedParses = new ArrayList<>();
|
||||
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
|
||||
parseResp.setSelectedParses(selectedParses);
|
||||
return parseResp;
|
||||
}).collect(Collectors.toList());
|
||||
return semanticParseInfoList;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
@@ -52,9 +52,9 @@ public class PluginManager {
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
Agent agent = parseContext.getAgent();
|
||||
List<ChatPlugin> plugins = pluginService.getPluginList();
|
||||
if (Objects.isNull(agent)) {
|
||||
return plugins;
|
||||
@@ -191,9 +191,9 @@ public class PluginManager {
|
||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||
}
|
||||
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
||||
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
@@ -259,8 +259,8 @@ public class PluginManager {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
|
||||
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpEntity;
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -28,22 +28,22 @@ import java.util.Set;
|
||||
*/
|
||||
public abstract class PluginRecognizer {
|
||||
|
||||
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(chatParseContext)) {
|
||||
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(parseContext)) {
|
||||
return;
|
||||
}
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(parseContext);
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(chatParseContext, parseResp, pluginRecallResult);
|
||||
buildQuery(parseContext, parseResp, pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
|
||||
public abstract boolean checkPreCondition(ParseContext parseContext);
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
|
||||
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
|
||||
|
||||
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
|
||||
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
@@ -52,35 +52,35 @@ public abstract class PluginRecognizer {
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
chatParseContext, pluginRecallResult.getDistance());
|
||||
parseContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(plugin.getType());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
protected List<ChatPlugin> getPluginList(ChatParseContext chatParseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(chatParseContext);
|
||||
protected List<ChatPlugin> getPluginList(ParseContext parseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(parseContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||
ChatParseContext chatParseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = chatParseContext.getQueryFilters();
|
||||
ParseContext parseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(schemaElement);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setQueryFilters(queryFilters);
|
||||
pluginParseResult.setDistance(distance);
|
||||
pluginParseResult.setQueryText(chatParseContext.getQueryText());
|
||||
pluginParseResult.setQueryText(parseContext.getQueryText());
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
@@ -26,25 +26,25 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
|
||||
public boolean checkPreCondition(ChatParseContext chatParseContext) {
|
||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
||||
public boolean checkPreCondition(ParseContext parseContext) {
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
|
||||
String text = chatParseContext.getQueryText();
|
||||
public PluginRecallResult recallPlugin(ParseContext parseContext) {
|
||||
String text = parseContext.getQueryText();
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, parseContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> dataSetList = pair.getRight();
|
||||
@@ -53,7 +53,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = chatParseContext.getQueryText().length() * (1 - distance);
|
||||
double score = parseContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -6,7 +6,6 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatContext {
|
||||
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
@@ -1,17 +1,17 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatExecuteContext {
|
||||
public class ExecuteContext {
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private Agent agent;
|
||||
private Integer chatId;
|
||||
private Long queryId;
|
||||
private boolean saveAnswer;
|
||||
private SemanticParseInfo parseInfo;
|
||||
}
|
||||
@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatParseContext {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Agent agent;
|
||||
public class ParseContext {
|
||||
private User user;
|
||||
private String queryText;
|
||||
private Agent agent;
|
||||
private Integer chatId;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaMapInfo mapInfo;
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
if (agent == null) {
|
||||
@@ -5,5 +5,4 @@ package com.tencent.supersonic.chat.server.processor;
|
||||
*/
|
||||
public interface ResultProcessor {
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -28,14 +28,14 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int recommend_dimension_size = 5;
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|
||||
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
|
||||
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
|
||||
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSetId());
|
||||
queryResult.setRecommendedDimensions(dimensionRecommended);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
/**
|
||||
* A ExecuteResultProcessor wraps things up before returning results to users in execute stage.
|
||||
* A ExecuteResultProcessor wraps things up before returning
|
||||
* execution results to the users.
|
||||
*/
|
||||
public interface ExecuteResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
|
||||
void process(ExecuteContext executeContext, QueryResult queryResult);
|
||||
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -26,7 +26,7 @@ import com.tencent.supersonic.headless.api.pojo.MetricInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
@@ -60,15 +60,15 @@ import org.springframework.util.CollectionUtils;
|
||||
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
|
||||
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
||||
|| !aggregatorConfig.getEnableRatio()
|
||||
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||
return;
|
||||
}
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
|
||||
semanticParseInfo, queryResult);
|
||||
queryResult.setAggregateInfo(aggregateInfo);
|
||||
}
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||
import java.util.Objects;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
@@ -23,6 +22,7 @@ import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -34,8 +34,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
fillSimilarMetric(chatExecuteContext.getParseInfo());
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
fillSimilarMetric(executeContext.getParseInfo());
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
@@ -45,8 +45,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
return;
|
||||
}
|
||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSet().toString());
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
|
||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
@@ -78,7 +78,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
if (retrieval.getMetadata().containsKey("dataSetId")) {
|
||||
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
|
||||
.replace(Constants.UNDERLINE, "");
|
||||
schemaElement.setDataSet(Long.parseLong(dataSetId));
|
||||
schemaElement.setDataSetId(Long.parseLong(dataSetId));
|
||||
}
|
||||
schemaElement.setOrder(++metricOrder);
|
||||
parseInfo.getMetrics().add(schemaElement);
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* EntityInfoProcessor fills core attributes of an entity so that
|
||||
* users get to know which entity is parsed out.
|
||||
*/
|
||||
public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
||||
return;
|
||||
}
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) {
|
||||
return;
|
||||
}
|
||||
|
||||
//1. set entity info
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ParseResultProcessor {
|
||||
/**
|
||||
* A ParseResultProcessor wraps things up before returning
|
||||
* parsing results to the users.
|
||||
*/
|
||||
public interface ParseResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
void process(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -25,15 +25,15 @@ import java.util.stream.Collectors;
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
|
||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
|
||||
chatParseContext.getAgent().getId());
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
|
||||
parseContext.getAgent().getId());
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
updateChatQuery(chatQueryDO);
|
||||
@@ -43,7 +43,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
List<Text2SQLExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
return exemplars.stream().map(sqlExemplar ->
|
||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(
|
||||
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
@@ -15,6 +15,7 @@ import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
@@ -29,16 +30,16 @@ public class AgentController {
|
||||
|
||||
@PostMapping
|
||||
public Agent createAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.createAgent(agent, user);
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public Agent updateAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.updateAgent(agent, user);
|
||||
}
|
||||
@@ -50,8 +51,8 @@ public class AgentController {
|
||||
}
|
||||
|
||||
@PostMapping("/testLLMConn")
|
||||
public boolean testLLMConn(@RequestBody LLMConfig llmConfig) {
|
||||
return LLMConnHelper.testConnection(llmConfig);
|
||||
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
|
||||
return LLMConnHelper.testConnection(modelConfig);
|
||||
}
|
||||
|
||||
@RequestMapping("/getAgentList")
|
||||
|
||||
@@ -6,11 +6,10 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -32,20 +31,20 @@ import javax.validation.Valid;
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@PostMapping("search")
|
||||
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.search(chatParseReq);
|
||||
return chatQueryService.search(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("parse")
|
||||
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performParsing(chatParseReq);
|
||||
return chatQueryService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("execute")
|
||||
@@ -53,7 +52,7 @@ public class ChatQueryController {
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
@@ -62,7 +61,7 @@ public class ChatQueryController {
|
||||
throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
chatParseReq.setUser(user);
|
||||
ParseResp parseResp = chatService.performParsing(chatParseReq);
|
||||
ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
|
||||
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
throw new InvalidArgumentException("parser error,no selectedParses");
|
||||
@@ -72,27 +71,20 @@ public class ChatQueryController {
|
||||
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
|
||||
chatExecuteReq.setQueryId(parseResp.getQueryId());
|
||||
chatExecuteReq.setParseId(semanticParseInfo.getId());
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryContext")
|
||||
public Object queryContext(@RequestBody QueryReq queryCtx,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryContext(queryCtx.getChatId());
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryData")
|
||||
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@PostMapping("queryDimensionValue")
|
||||
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextService {
|
||||
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -35,6 +35,8 @@ public interface ChatManageService {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
||||
|
||||
ChatQueryDO saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
|
||||
|
||||
@@ -4,15 +4,14 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatService {
|
||||
public interface ChatQueryService {
|
||||
|
||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||
|
||||
@@ -24,8 +23,6 @@ public interface ChatService {
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
SemanticParseInfo queryContext(Integer chatId);
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
}
|
||||
@@ -16,6 +16,10 @@ public interface MemoryService {
|
||||
|
||||
void updateMemory(ChatMemoryDO memory);
|
||||
|
||||
void enableMemory(ChatMemoryDO memory);
|
||||
|
||||
void disableMemory(ChatMemoryDO memory);
|
||||
|
||||
PageInfo<ChatMemoryDO> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
|
||||
List<ChatMemoryDO> getMemories(ChatMemoryFilter chatMemoryFilter);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -9,16 +9,18 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
@@ -34,7 +36,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||
|
||||
@@ -78,6 +80,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
/**
|
||||
* the example in the agent will be executed by default,
|
||||
* if the result is correct, it will be put into memory as a reference for LLM
|
||||
*
|
||||
* @param agent
|
||||
*/
|
||||
private void executeAgentExamplesAsync(Agent agent) {
|
||||
@@ -85,9 +88,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
}
|
||||
|
||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig())) {
|
||||
if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|
||||
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> examples = agent.getExamples();
|
||||
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId())
|
||||
.questions(examples).build();
|
||||
@@ -98,7 +103,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
chatService.parseAndExecute(-1, agent.getId(), example);
|
||||
chatQueryService.parseAndExecute(-1, agent.getId(), example);
|
||||
} catch (Exception e) {
|
||||
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
|
||||
}
|
||||
@@ -117,7 +122,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
BeanUtils.copyProperties(agentDO, agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
|
||||
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
|
||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
return agent;
|
||||
@@ -128,9 +134,10 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setConfig(agent.getAgentConfig());
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
|
||||
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig()));
|
||||
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatContextServiceImpl implements ChatContextService {
|
||||
|
||||
private ChatContextRepository chatContextRepository;
|
||||
|
||||
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
|
||||
this.chatContextRepository = chatContextRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatContext getOrCreateContext(Integer chatId) {
|
||||
return chatContextRepository.getOrCreateContext(chatId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
log.debug("save ChatContext {}", chatCtx);
|
||||
chatContextRepository.updateContext(chatCtx);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,7 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -107,6 +107,13 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
return chatQueryRepository.getChatQuery(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
|
||||
fillParseInfo(queries);
|
||||
return queries;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||
ShowCaseResp showCaseResp = new ShowCaseResp();
|
||||
|
||||
@@ -0,0 +1,548 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Autowired
|
||||
private ChatManageService chatManageService;
|
||||
@Autowired
|
||||
private ChatLayerService chatLayerService;
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
return chatLayerService.retrieve(queryNLReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(parseContext);
|
||||
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
|
||||
chatQueryParser.parse(parseContext, parseResp);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(parseContext, parseResp);
|
||||
}
|
||||
chatParseReq.setQueryText(parseContext.getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
chatManageService.updateParseCostTime(parseResp);
|
||||
return parseResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
||||
queryResult = chatQueryExecutor.execute(executeContext);
|
||||
if (queryResult != null) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (queryResult != null) {
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(executeContext, queryResult);
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(int chatId, int agentId, String queryText) {
|
||||
ChatParseReq chatParseReq = new ChatParseReq();
|
||||
chatParseReq.setQueryText(queryText);
|
||||
chatParseReq.setChatId(chatId);
|
||||
chatParseReq.setAgentId(agentId);
|
||||
chatParseReq.setUser(User.getFakeUser());
|
||||
ParseResp parseResp = performParsing(chatParseReq);
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
||||
chatId, agentId, queryText);
|
||||
return null;
|
||||
}
|
||||
ChatExecuteReq executeReq = new ChatExecuteReq();
|
||||
executeReq.setQueryId(parseResp.getQueryId());
|
||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
executeReq.setQueryText(queryText);
|
||||
executeReq.setChatId(chatId);
|
||||
executeReq.setUser(User.getFakeUser());
|
||||
executeReq.setAgentId(agentId);
|
||||
executeReq.setSaveAnswer(true);
|
||||
return performExecution(executeReq);
|
||||
}
|
||||
|
||||
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = new ParseContext();
|
||||
BeanMapper.mapper(chatParseReq, parseContext);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
parseContext.setAgent(agent);
|
||||
return parseContext;
|
||||
}
|
||||
|
||||
private void supplyMapInfo(ParseContext parseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp mapResp = chatLayerService.performMapping(queryNLReq);
|
||||
parseContext.setMapInfo(mapResp.getMapInfo());
|
||||
}
|
||||
|
||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ExecuteContext executeContext = new ExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
||||
executeContext.setAgent(agent);
|
||||
executeContext.setParseInfo(parseInfo);
|
||||
return executeContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
|
||||
Integer parseId = chatQueryDataReq.getParseId();
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
||||
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
|
||||
DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) {
|
||||
handleLLMQueryMode(chatQueryDataReq, semanticQuery, user);
|
||||
} else {
|
||||
handleRuleQueryMode(semanticQuery, dataSetSchema, user);
|
||||
}
|
||||
|
||||
return executeQuery(semanticQuery, user, dataSetSchema);
|
||||
}
|
||||
|
||||
private List<String> getFieldsFromSql(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isNotBlank(sqlInfo.getCorrectedS2SQL())) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq,
|
||||
SemanticQuery semanticQuery,
|
||||
User user) throws Exception {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
List<String> fields = getFieldsFromSql(parseInfo);
|
||||
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
|
||||
log.info("llm begin replace metrics!");
|
||||
SchemaElement metricToReplace = chatQueryDataReq.getMetrics().iterator().next();
|
||||
replaceMetrics(parseInfo, metricToReplace);
|
||||
} else {
|
||||
log.info("llm begin revise filters!");
|
||||
String correctorSql = reviseCorrectS2SQL(chatQueryDataReq, parseInfo);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user);
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
}
|
||||
}
|
||||
|
||||
private void handleRuleQueryMode(SemanticQuery semanticQuery,
|
||||
DataSetSchema dataSetSchema,
|
||||
User user) {
|
||||
log.info("rule begin replace metrics and revise filters!");
|
||||
validFilter(semanticQuery.getParseInfo().getDimensionFilters());
|
||||
validFilter(semanticQuery.getParseInfo().getMetricFilters());
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
}
|
||||
|
||||
private QueryResult executeQuery(SemanticQuery semanticQuery,
|
||||
User user,
|
||||
DataSetSchema dataSetSchema) throws Exception {
|
||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
|
||||
queryResult.setChatContext(semanticQuery.getParseInfo());
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
|
||||
queryResult.setEntityInfo(entityInfo);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private boolean checkMetricReplace(List<String> oriFields, Set<SchemaElement> metrics) {
|
||||
if (CollectionUtils.isEmpty(oriFields) || CollectionUtils.isEmpty(metrics)) {
|
||||
return false;
|
||||
}
|
||||
List<String> metricNames = metrics.stream().map(SchemaElement::getName).collect(Collectors.toList());
|
||||
return !oriFields.containsAll(metricNames);
|
||||
}
|
||||
|
||||
private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo) {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql);
|
||||
|
||||
// replace where filter
|
||||
List<Expression> addWhereConditions = new ArrayList<>();
|
||||
Set<String> removeWhereFieldNames = updateFilters(whereExpressionList, queryData.getDimensionFilters(),
|
||||
parseInfo.getDimensionFilters(), addWhereConditions);
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||
Set<String> removeDataFieldNames = updateDateInfo(queryData, parseInfo, filedNameToValueMap,
|
||||
whereExpressionList, addWhereConditions);
|
||||
removeWhereFieldNames.addAll(removeDataFieldNames);
|
||||
|
||||
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
|
||||
correctorSql = SqlRemoveHelper.removeWhereCondition(correctorSql, removeWhereFieldNames);
|
||||
|
||||
// replace having filter
|
||||
List<FieldExpression> havingExpressionList = SqlSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<Expression> addHavingConditions = new ArrayList<>();
|
||||
Set<String> removeHavingFieldNames = updateFilters(havingExpressionList,
|
||||
queryData.getDimensionFilters(), parseInfo.getDimensionFilters(), addHavingConditions);
|
||||
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>());
|
||||
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
|
||||
|
||||
correctorSql = SqlAddHelper.addWhere(correctorSql, addWhereConditions);
|
||||
correctorSql = SqlAddHelper.addHaving(correctorSql, addHavingConditions);
|
||||
log.info("correctorSql after replacing:{}", correctorSql);
|
||||
return correctorSql;
|
||||
}
|
||||
|
||||
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
|
||||
List<String> oriMetrics = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
log.info("before replaceMetrics:{}", correctorSql);
|
||||
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
|
||||
Map<String, Pair<String, String>> fieldMap = new HashMap<>();
|
||||
if (!CollectionUtils.isEmpty(oriMetrics) && !oriMetrics.contains(metric.getName())) {
|
||||
fieldMap.put(oriMetrics.get(0), Pair.of(metric.getName(), metric.getDefaultAgg()));
|
||||
correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap);
|
||||
}
|
||||
log.info("after replaceMetrics:{}", correctorSql);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql);
|
||||
}
|
||||
|
||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, String queryMode, User user) throws Exception {
|
||||
SemanticQueryResp queryResp = semanticLayerService.queryByReq(semanticQueryReq, user);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQueryResults(queryResp.getResultList());
|
||||
queryResult.setQueryColumns(queryResp.getColumns());
|
||||
} else {
|
||||
queryResult.setQueryResults(new ArrayList<>());
|
||||
queryResult.setQueryColumns(new ArrayList<>());
|
||||
}
|
||||
|
||||
queryResult.setQueryMode(queryMode);
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private Set<String> updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
Map<String, Map<String, String>> filedNameToValueMap,
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
List<Expression> addConditions) {
|
||||
Set<String> removeFieldNames = new HashSet<>();
|
||||
if (Objects.isNull(queryData.getDateInfo())) {
|
||||
return removeFieldNames;
|
||||
}
|
||||
if (queryData.getDateInfo().getUnit() > 1) {
|
||||
queryData.getDateInfo().setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
|
||||
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(1));
|
||||
}
|
||||
// startDate equals to endDate
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||
// first remove,then add
|
||||
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||
MinorThanEquals minorThanEquals = new MinorThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions);
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
|
||||
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
|
||||
&& FilterOperatorEnum.LIKE.getValue().equalsIgnoreCase(
|
||||
fieldExpression.getOperator())) {
|
||||
Map<String, String> replaceMap = new HashMap<>();
|
||||
String preValue = fieldExpression.getFieldValue().toString();
|
||||
String curValue = queryFilter.getValue().toString();
|
||||
if (preValue.startsWith("%")) {
|
||||
curValue = "%" + curValue;
|
||||
}
|
||||
if (preValue.endsWith("%")) {
|
||||
curValue = curValue + "%";
|
||||
}
|
||||
replaceMap.put(preValue, curValue);
|
||||
filedNameToValueMap.put(fieldExpression.getFieldName(), replaceMap);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
return removeFieldNames;
|
||||
}
|
||||
|
||||
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
||||
T comparisonExpression,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
||||
StringValue stringValue = new StringValue(date);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
addConditions.add(comparisonExpression);
|
||||
}
|
||||
|
||||
private Set<String> updateFilters(List<FieldExpression> fieldExpressionList,
|
||||
Set<QueryFilter> metricFilters,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
Set<String> removeFieldNames = new HashSet<>();
|
||||
if (CollectionUtils.isEmpty(metricFilters)) {
|
||||
return removeFieldNames;
|
||||
}
|
||||
|
||||
for (QueryFilter dslQueryFilter : metricFilters) {
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (fieldExpression.getFieldName() != null
|
||||
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||
removeFieldNames.add(dslQueryFilter.getName());
|
||||
handleFilter(dslQueryFilter, contextMetricFilters, addConditions);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return removeFieldNames;
|
||||
}
|
||||
|
||||
private void handleFilter(QueryFilter dslQueryFilter,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
FilterOperatorEnum operator = dslQueryFilter.getOperator();
|
||||
|
||||
if (operator == FilterOperatorEnum.IN) {
|
||||
addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, addConditions);
|
||||
} else {
|
||||
ComparisonOperator expression = FilterOperatorEnum.createExpression(operator);
|
||||
if (Objects.nonNull(expression)) {
|
||||
addWhereFilters(dslQueryFilter, expression, contextMetricFilters, addConditions);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// add in condition to sql where condition
|
||||
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
||||
InExpression inExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
Column column = new Column(dslQueryFilter.getName());
|
||||
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
|
||||
List<String> valueList = JsonUtil.toList(
|
||||
JsonUtil.toString(dslQueryFilter.getValue()), String.class);
|
||||
if (CollectionUtils.isEmpty(valueList)) {
|
||||
return;
|
||||
}
|
||||
valueList.stream().forEach(o -> {
|
||||
StringValue stringValue = new StringValue(o);
|
||||
parenthesedExpressionList.add(stringValue);
|
||||
});
|
||||
inExpression.setLeftExpression(column);
|
||||
inExpression.setRightExpression(parenthesedExpressionList);
|
||||
addConditions.add(inExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// add where filter
|
||||
private void addWhereFilters(QueryFilter dslQueryFilter,
|
||||
ComparisonOperator comparisonExpression,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions) {
|
||||
String columnName = dslQueryFilter.getName();
|
||||
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
||||
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
||||
}
|
||||
if (Objects.isNull(dslQueryFilter.getValue())) {
|
||||
return;
|
||||
}
|
||||
Column column = new Column(columnName);
|
||||
comparisonExpression.setLeftExpression(column);
|
||||
if (StringUtils.isNumeric(dslQueryFilter.getValue().toString())) {
|
||||
LongValue longValue = new LongValue(Long.parseLong(dslQueryFilter.getValue().toString()));
|
||||
comparisonExpression.setRightExpression(longValue);
|
||||
} else {
|
||||
StringValue stringValue = new StringValue(dslQueryFilter.getValue().toString());
|
||||
comparisonExpression.setRightExpression(stringValue);
|
||||
}
|
||||
addConditions.add(comparisonExpression);
|
||||
contextMetricFilters.stream().forEach(o -> {
|
||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||
o.setValue(dslQueryFilter.getValue());
|
||||
o.setOperator(dslQueryFilter.getOperator());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
|
||||
ChatQueryDataReq queryData) {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return parseInfo;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
|
||||
parseInfo.setDimensions(queryData.getDimensions());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getMetrics())) {
|
||||
parseInfo.setMetrics(queryData.getMetrics());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getDimensionFilters())) {
|
||||
parseInfo.setDimensionFilters(queryData.getDimensionFilters());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getMetricFilters())) {
|
||||
parseInfo.setMetricFilters(queryData.getMetricFilters());
|
||||
}
|
||||
if (Objects.nonNull(queryData.getDateInfo())) {
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
private void validFilter(Set<QueryFilter> filters) {
|
||||
Iterator<QueryFilter> iterator = filters.iterator();
|
||||
while (iterator.hasNext()) {
|
||||
QueryFilter queryFilter = iterator.next();
|
||||
Object queryFilterValue = queryFilter.getValue();
|
||||
if (Objects.isNull(queryFilterValue)) {
|
||||
iterator.remove();
|
||||
continue;
|
||||
}
|
||||
List<String> collection = JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
|
||||
if (FilterOperatorEnum.IN.equals(queryFilter.getOperator())
|
||||
&& CollectionUtils.isEmpty(collection)) {
|
||||
iterator.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) {
|
||||
Integer agentId = dimensionValueReq.getAgentId();
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
||||
return semanticLayerService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
//The history record only retains the query result of the first parse
|
||||
if (chatExecuteReq.getParseId() > 1) {
|
||||
return;
|
||||
}
|
||||
chatManageService.saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatServiceImpl implements ChatService {
|
||||
|
||||
@Autowired
|
||||
private ChatManageService chatManageService;
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
@Autowired
|
||||
private RetrieveService retrieveService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
return retrieveService.retrieve(queryReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(chatParseContext);
|
||||
for (ChatParser chatParser : chatParsers) {
|
||||
chatParser.parse(chatParseContext, parseResp);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(chatParseContext, parseResp);
|
||||
}
|
||||
chatParseReq.setQueryText(chatParseContext.getQueryText());
|
||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
chatManageService.updateParseCostTime(parseResp);
|
||||
return parseResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatExecutor chatExecutor : chatExecutors) {
|
||||
queryResult = chatExecutor.execute(chatExecuteContext);
|
||||
if (queryResult != null) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (queryResult != null) {
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(chatExecuteContext, queryResult);
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(int chatId, int agentId, String queryText) {
|
||||
ChatParseReq chatParseReq = new ChatParseReq();
|
||||
chatParseReq.setQueryText(queryText);
|
||||
chatParseReq.setChatId(chatId);
|
||||
chatParseReq.setAgentId(agentId);
|
||||
chatParseReq.setUser(User.getFakeUser());
|
||||
ParseResp parseResp = performParsing(chatParseReq);
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
||||
chatId, agentId, queryText);
|
||||
return null;
|
||||
}
|
||||
ChatExecuteReq executeReq = new ChatExecuteReq();
|
||||
executeReq.setQueryId(parseResp.getQueryId());
|
||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
executeReq.setQueryText(queryText);
|
||||
executeReq.setChatId(parseResp.getChatId());
|
||||
executeReq.setUser(User.getFakeUser());
|
||||
executeReq.setAgentId(agentId);
|
||||
executeReq.setSaveAnswer(true);
|
||||
return performExecution(executeReq);
|
||||
}
|
||||
|
||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
chatParseContext.setAgent(agent);
|
||||
return chatParseContext;
|
||||
}
|
||||
|
||||
private void supplyMapInfo(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryReq);
|
||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
||||
}
|
||||
|
||||
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
chatExecuteContext.setParseInfo(parseInfo);
|
||||
return chatExecuteContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
|
||||
Integer parseId = chatQueryDataReq.getParseId();
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatQueryDataReq.getQueryId(), parseId);
|
||||
QueryDataReq queryData = new QueryDataReq();
|
||||
BeanMapper.mapper(chatQueryDataReq, queryData);
|
||||
queryData.setParseInfo(parseInfo);
|
||||
return chatQueryService.executeDirectQuery(queryData, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo queryContext(Integer chatId) {
|
||||
return chatQueryService.queryContext(chatId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
||||
Integer agentId = dimensionValueReq.getAgentId();
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
//The history record only retains the query result of the first parse
|
||||
if (chatExecuteReq.getParseId() > 1) {
|
||||
return;
|
||||
}
|
||||
chatManageService.saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -30,9 +30,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.web.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.web.service.MetricService;
|
||||
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -53,18 +51,13 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
|
||||
private final ChatConfigRepository chatConfigRepository;
|
||||
private final ChatConfigHelper chatConfigHelper;
|
||||
private final DimensionService dimensionService;
|
||||
private final MetricService metricService;
|
||||
private final SemanticLayerService semanticLayerService;
|
||||
|
||||
|
||||
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
|
||||
ChatConfigHelper chatConfigHelper, DimensionService dimensionService,
|
||||
MetricService metricService, SemanticLayerService semanticLayerService) {
|
||||
ChatConfigHelper chatConfigHelper, SemanticLayerService semanticLayerService) {
|
||||
this.chatConfigRepository = chatConfigRepository;
|
||||
this.chatConfigHelper = chatConfigHelper;
|
||||
this.dimensionService = dimensionService;
|
||||
this.metricService = metricService;
|
||||
this.semanticLayerService = semanticLayerService;
|
||||
}
|
||||
|
||||
@@ -136,14 +129,14 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setModelIds(Lists.newArrayList(modelId));
|
||||
if (!CollectionUtils.isEmpty(blackDimIdList)) {
|
||||
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(metaFilter);
|
||||
List<DimensionResp> dimensionRespList = semanticLayerService.getDimensions(metaFilter);
|
||||
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackDimNameList(blackDimNameList);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
|
||||
|
||||
List<MetricResp> metricRespList = metricService.getMetrics(metaFilter);
|
||||
List<MetricResp> metricRespList = semanticLayerService.getMetrics(metaFilter);
|
||||
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackMetricNameList(blackMetricList);
|
||||
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import java.util.List;
|
||||
@@ -96,19 +96,25 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
return chatMemoryRepository.getMemories(queryWrapper);
|
||||
}
|
||||
|
||||
private void enableMemory(ChatMemoryDO memory) {
|
||||
@Override
|
||||
public void enableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.ENABLED);
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
Text2SQLExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
}
|
||||
|
||||
private void disableMemory(ChatMemoryDO memory) {
|
||||
@Override
|
||||
public void disableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.DISABLED);
|
||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
Text2SQLExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql())
|
||||
.build());
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
@@ -16,8 +16,8 @@ import java.util.List;
|
||||
public class ComponentFactory {
|
||||
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
||||
private static List<ChatParser> chatParsers = new ArrayList<>();
|
||||
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
|
||||
private static List<ChatQueryParser> chatQueryParsers = new ArrayList<>();
|
||||
private static List<ChatQueryExecutor> chatQueryExecutors = new ArrayList<>();
|
||||
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
||||
|
||||
public static List<ParseResultProcessor> getParseProcessors() {
|
||||
@@ -30,14 +30,14 @@ public class ComponentFactory {
|
||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||
}
|
||||
|
||||
public static List<ChatParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatParsers)
|
||||
? init(ChatParser.class, chatParsers) : chatParsers;
|
||||
public static List<ChatQueryParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatQueryParsers)
|
||||
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
|
||||
}
|
||||
|
||||
public static List<ChatExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatExecutors)
|
||||
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
|
||||
public static List<ChatQueryExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatQueryExecutors)
|
||||
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
|
||||
}
|
||||
|
||||
public static List<PluginRecognizer> getPluginRecognizers() {
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMConnHelper {
|
||||
public static boolean testConnection(LLMConfig llmConfig) {
|
||||
public static boolean testConnection(ChatModelConfig modelConfig) {
|
||||
try {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||
return false;
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||
String response = chatLanguageModel.generate("Hi there");
|
||||
return StringUtils.isNotEmpty(response) ? true : false;
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -1,37 +1,53 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
BeanMapper.mapper(chatParseContext, queryReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
|
||||
return buildText2SqlQueryReq(parseContext, null);
|
||||
}
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||
QueryNLReq queryNLReq = new QueryNLReq();
|
||||
BeanMapper.mapper(parseContext, queryNLReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
return queryNLReq;
|
||||
}
|
||||
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
|
||||
queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||
} else if (agent.containsLLMParserTool()) {
|
||||
queryReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||
} else if (agent.containsRuleTool()) {
|
||||
queryReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
|
||||
boolean hasLLMTool = agent.containsLLMParserTool();
|
||||
boolean hasRuleTool = agent.containsRuleTool();
|
||||
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
|
||||
|
||||
if (hasLLMTool && hasLLMConfig) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||
} else if (hasLLMTool && hasRuleTool) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||
} else if (hasLLMTool) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||
} else if (hasRuleTool) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
}
|
||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||
if (Objects.nonNull(queryReq.getMapInfo())
|
||||
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryReq.setMapInfo(queryReq.getMapInfo());
|
||||
queryNLReq.setDataSetIds(agent.getDataSetIds());
|
||||
if (Objects.nonNull(queryNLReq.getMapInfo())
|
||||
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
||||
}
|
||||
queryReq.setLlmConfig(agent.getLlmConfig());
|
||||
return queryReq;
|
||||
queryNLReq.setModelConfig(agent.getModelConfig());
|
||||
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
||||
if (chatCtx != null) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
return queryNLReq;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
|
||||
|
||||
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper">
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
|
||||
|
||||
<resultMap id="ChatContextDO"
|
||||
type="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO">
|
||||
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
|
||||
<id column="chat_id" property="chatId"/>
|
||||
<result column="modified_at" property="modifiedAt"/>
|
||||
<result column="user" property="user"/>
|
||||
@@ -20,7 +20,7 @@
|
||||
from s2_chat_context where chat_id=#{chatId} limit 1
|
||||
</select>
|
||||
|
||||
<insert id="addContext" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO" >
|
||||
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
|
||||
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
|
||||
</insert>
|
||||
<update id="updateContext">
|
||||
@@ -3,9 +3,9 @@
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
|
||||
|
||||
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper">
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
|
||||
|
||||
<resultMap id="Statistics" type="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
|
||||
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||
<id column="question_id" property="questionId"/>
|
||||
<result column="chat_id" property="chatId"/>
|
||||
<result column="user_name" property="userName"/>
|
||||
@@ -16,7 +16,7 @@
|
||||
<result column="create_time" property="createTime"/>
|
||||
</resultMap>
|
||||
|
||||
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
|
||||
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||
insert into s2_chat_statistics
|
||||
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
|
||||
values
|
||||
@@ -46,11 +46,6 @@
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-lang</groupId>
|
||||
<artifactId>commons-lang</artifactId>
|
||||
<version>${commons.lang.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
@@ -67,6 +62,11 @@
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-compress</artifactId>
|
||||
<version>${commons.compress.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("ChatModelParameterConfig")
|
||||
@Slf4j
|
||||
public class ChatModelParameterConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
|
||||
"接口协议", "", "list",
|
||||
"对话模型配置", getCandidateValues());
|
||||
|
||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
"BaseUrl", "", "string",
|
||||
"对话模型配置", null, getBaseUrlDependency());
|
||||
public static final Parameter CHAT_MODEL_ENDPOINT =
|
||||
new Parameter("s2.chat.model.endpoint", "llama_2_70b",
|
||||
"Endpoint", "", "string",
|
||||
"对话模型配置", null, getEndpointDependency());
|
||||
public static final Parameter CHAT_MODEL_API_KEY =
|
||||
new Parameter("s2.chat.model.api.key", DEMO,
|
||||
"ApiKey", "", "password",
|
||||
"对话模型配置", null, getApiKeyDependency()
|
||||
);
|
||||
public static final Parameter CHAT_MODEL_SECRET_KEY =
|
||||
new Parameter("s2.chat.model.secretKey", "demo",
|
||||
"SecretKey", "", "password",
|
||||
"对话模型配置", null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_NAME =
|
||||
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
|
||||
"ModelName", "", "string",
|
||||
"对话模型配置", null, getModelNameDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
|
||||
new Parameter("s2.chat.model.enableSearch", "false",
|
||||
"是否启用搜索增强功能,设为false表示不启用", "", "bool",
|
||||
"对话模型配置", null, getEnableSearchDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||
new Parameter("s2.chat.model.temperature", "0.0",
|
||||
"Temperature", "",
|
||||
"slider", "对话模型配置");
|
||||
|
||||
public static final Parameter CHAT_MODEL_TIMEOUT =
|
||||
new Parameter("s2.chat.model.timeout", "60",
|
||||
"超时时间(秒)", "",
|
||||
"number", "对话模型配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
||||
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
|
||||
);
|
||||
}
|
||||
|
||||
public ChatModelConfig convert() {
|
||||
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
|
||||
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
|
||||
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
|
||||
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
|
||||
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
|
||||
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
|
||||
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
|
||||
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
|
||||
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
|
||||
|
||||
return ChatModelConfig.builder()
|
||||
.provider(chatModelProvider)
|
||||
.baseUrl(chatModelBaseUrl)
|
||||
.apiKey(chatModelApiKey)
|
||||
.modelName(chatModelName)
|
||||
.enableSearch(Boolean.valueOf(enableSearch))
|
||||
.temperature(Double.valueOf(chatModelTemperature))
|
||||
.timeOut(Long.valueOf(chatModelTimeout))
|
||||
.endpoint(endpoint)
|
||||
.secretKey(secretKey)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static List<String> getCandidateValues() {
|
||||
return Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, DEMO,
|
||||
QianfanModelFactory.PROVIDER, DEMO,
|
||||
ZhipuModelFactory.PROVIDER, DEMO,
|
||||
LocalAiModelFactory.PROVIDER, DEMO,
|
||||
AzureModelFactory.PROVIDER, DEMO,
|
||||
DashscopeModelFactory.PROVIDER, DEMO
|
||||
));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
getCandidateValues(),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false")
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -16,7 +16,10 @@ public class DataBaseConfig {
|
||||
@Primary
|
||||
@ConfigurationProperties("spring.datasource")
|
||||
public DataSource dataSource() {
|
||||
return new DruidDataSource();
|
||||
DruidDataSource druidDataSource = new DruidDataSource();
|
||||
druidDataSource.setTestWhileIdle(true);
|
||||
druidDataSource.setValidationQuery("select 1");
|
||||
return druidDataSource;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Service("EmbeddingModelParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
public static final Parameter EMBEDDING_MODEL_PROVIDER =
|
||||
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
|
||||
"接口协议", "", "list",
|
||||
"向量模型配置", getCandidateValues());
|
||||
public static final Parameter EMBEDDING_MODEL_BASE_URL =
|
||||
new Parameter("s2.embedding.model.base.url", "",
|
||||
"BaseUrl", "", "string",
|
||||
"向量模型配置", null, getBaseUrlDependency()
|
||||
);
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_API_KEY =
|
||||
new Parameter("s2.embedding.model.api.key", "",
|
||||
"ApiKey", "", "password",
|
||||
"向量模型配置", null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
|
||||
new Parameter("s2.embedding.model.secretKey", "demo",
|
||||
"SecretKey", "", "password",
|
||||
"向量模型配置", null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_NAME =
|
||||
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
"ModelName", "", "string",
|
||||
"向量模型配置", null, getModelNameDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_MODEL_PATH =
|
||||
new Parameter("s2.embedding.model.path", "",
|
||||
"模型路径", "", "string",
|
||||
"向量模型配置", null, getModelPathDependency());
|
||||
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
|
||||
new Parameter("s2.embedding.model.vocabulary.path", "",
|
||||
"词汇表路径", "", "string",
|
||||
"向量模型配置", null, getModelPathDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_API_KEY,
|
||||
EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME, EMBEDDING_MODEL_PATH,
|
||||
EMBEDDING_MODEL_VOCABULARY_PATH
|
||||
);
|
||||
}
|
||||
|
||||
public EmbeddingModelConfig convert() {
|
||||
String provider = getParameterValue(EMBEDDING_MODEL_PROVIDER);
|
||||
String baseUrl = getParameterValue(EMBEDDING_MODEL_BASE_URL);
|
||||
String apiKey = getParameterValue(EMBEDDING_MODEL_API_KEY);
|
||||
String modelName = getParameterValue(EMBEDDING_MODEL_NAME);
|
||||
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
|
||||
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
|
||||
String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
|
||||
return EmbeddingModelConfig.builder()
|
||||
.provider(provider)
|
||||
.baseUrl(baseUrl)
|
||||
.apiKey(apiKey)
|
||||
.secretKey(secretKey)
|
||||
.modelName(modelName)
|
||||
.modelPath(modelPath)
|
||||
.vocabularyPath(vocabularyPath)
|
||||
.build();
|
||||
}
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO,
|
||||
AzureModelFactory.PROVIDER, DEMO,
|
||||
DashscopeModelFactory.PROVIDER, DEMO,
|
||||
QianfanModelFactory.PROVIDER, DEMO,
|
||||
ZhipuModelFactory.PROVIDER, DEMO)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
InMemoryModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER
|
||||
),
|
||||
ImmutableMap.of(
|
||||
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER),
|
||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreType;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Service("EmbeddingStoreParameterConfig")
|
||||
@Slf4j
|
||||
public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
public static final Parameter EMBEDDING_STORE_PROVIDER =
|
||||
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
|
||||
"向量库类型", "目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list",
|
||||
"向量库配置", getCandidateValues());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||
new Parameter("s2.embedding.store.base.url", "",
|
||||
"BaseUrl", "", "string",
|
||||
"向量库配置", null, getBaseUrlDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_API_KEY =
|
||||
new Parameter("s2.embedding.store.api.key", "",
|
||||
"ApiKey", "", "password",
|
||||
"向量库配置", null, getApiKeyDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
|
||||
new Parameter("s2.embedding.store.persist.path", "",
|
||||
"持久化路径", "默认不持久化,如需持久化请填写持久化路径。"
|
||||
+ "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
|
||||
"向量库配置", null, getPathDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_TIMEOUT =
|
||||
new Parameter("s2.embedding.store.timeout", "60",
|
||||
"超时时间(秒)", "",
|
||||
"number", "向量库配置");
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_DIMENSION =
|
||||
new Parameter("s2.embedding.store.dimension", "",
|
||||
"纬度", "", "number",
|
||||
"向量库配置", null, getDimensionDependency());
|
||||
public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
|
||||
new Parameter("s2.embedding.store.databaseName", "",
|
||||
"DatabaseName", "", "string",
|
||||
"向量库配置", null, getDatabaseNameDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_API_KEY,
|
||||
EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH,
|
||||
EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION
|
||||
);
|
||||
}
|
||||
|
||||
public EmbeddingStoreConfig convert() {
|
||||
String provider = getParameterValue(EMBEDDING_STORE_PROVIDER);
|
||||
String baseUrl = getParameterValue(EMBEDDING_STORE_BASE_URL);
|
||||
String apiKey = getParameterValue(EMBEDDING_STORE_API_KEY);
|
||||
String persistPath = getParameterValue(EMBEDDING_STORE_PERSIST_PATH);
|
||||
String timeOut = getParameterValue(EMBEDDING_STORE_TIMEOUT);
|
||||
String databaseName = getParameterValue(EMBEDDING_STORE_DATABASE_NAME);
|
||||
Integer dimension = null;
|
||||
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
|
||||
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
|
||||
}
|
||||
return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl)
|
||||
.apiKey(apiKey).persistPath(persistPath).databaseName(databaseName)
|
||||
.timeOut(Long.valueOf(timeOut)).dimension(dimension).build();
|
||||
}
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(
|
||||
EmbeddingStoreType.IN_MEMORY.name(),
|
||||
EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.CHROMA.name());
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(
|
||||
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()),
|
||||
ImmutableMap.of(
|
||||
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getPathDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), ""));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")
|
||||
);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class LLMConfig {
|
||||
|
||||
private String provider;
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
private String apiKey;
|
||||
|
||||
private String modelName;
|
||||
|
||||
private Double temperature = 0.0d;
|
||||
|
||||
private Long timeOut = 60L;
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
}
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
|
||||
double temperature) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
this.temperature = temperature;
|
||||
}
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(apiKey);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -7,11 +7,14 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
public abstract class ParameterConfig {
|
||||
|
||||
public static final String DEMO = "demo";
|
||||
@Autowired
|
||||
private SystemConfigService sysConfigService;
|
||||
|
||||
@@ -21,13 +24,16 @@ public abstract class ParameterConfig {
|
||||
/**
|
||||
* @return system parameters to be set with user interface
|
||||
*/
|
||||
protected abstract List<Parameter> getSysParameters();
|
||||
protected List<Parameter> getSysParameters() {
|
||||
return Collections.EMPTY_LIST;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parameter value will be derived in the following order:
|
||||
* 1. `system config` set with user interface
|
||||
* 2. `system property` set with application.yaml file
|
||||
* 3. `default value` set with parameter declaration
|
||||
*
|
||||
* @param parameter instance
|
||||
* @return parameter value
|
||||
*/
|
||||
@@ -44,4 +50,22 @@ public abstract class ParameterConfig {
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
protected static List<Parameter.Dependency> getDependency(
|
||||
String dependencyParameterName,
|
||||
List<String> includesValue,
|
||||
Map<String, String> setDefaultValue) {
|
||||
|
||||
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
|
||||
show.setIncludesValue(includesValue);
|
||||
|
||||
Parameter.Dependency dependency = new Parameter.Dependency();
|
||||
dependency.setName(dependencyParameterName);
|
||||
dependency.setShow(show);
|
||||
dependency.setSetDefaultValue(setDefaultValue);
|
||||
List<Parameter.Dependency> dependencies = new ArrayList<>();
|
||||
dependencies.add(dependency);
|
||||
return dependencies;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class PromptConfig {
|
||||
|
||||
private String promptTemplate;
|
||||
|
||||
}
|
||||
@@ -29,8 +29,7 @@ public enum AggregateEnum {
|
||||
}
|
||||
|
||||
public static Map<String, String> getAggregateEnum() {
|
||||
Map<String, String> aggregateMap = Arrays.stream(AggregateEnum.values())
|
||||
return Arrays.stream(AggregateEnum.values())
|
||||
.collect(Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
|
||||
return aggregateMap;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,36 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||
private Map<String, String> fieldNameMap;
|
||||
private boolean exactReplace;
|
||||
private ThreadLocal<Boolean> exactReplace = ThreadLocal.withInitial(() -> false);
|
||||
|
||||
public FieldReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||
this.fieldNameMap = fieldNameMap;
|
||||
this.exactReplace = exactReplace;
|
||||
this.exactReplace.set(exactReplace);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(Column column) {
|
||||
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace);
|
||||
parseVisitorHelper.replaceColumn(column, fieldNameMap, exactReplace.get());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(Function function) {
|
||||
boolean originalExactReplace = exactReplace.get();
|
||||
exactReplace.set(true);
|
||||
try {
|
||||
super.visit(function);
|
||||
} finally {
|
||||
exactReplace.set(originalExactReplace);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,5 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.DoubleValue;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -24,14 +19,19 @@ import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||
private boolean exactReplace;
|
||||
private Map<String, Map<String, String>> filedNameToValueMap;
|
||||
|
||||
public FieldlValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||
public FieldValueReplaceVisitor(boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) {
|
||||
this.exactReplace = exactReplace;
|
||||
this.filedNameToValueMap = filedNameToValueMap;
|
||||
}
|
||||
@@ -71,17 +71,13 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
values.add(((StringValue) o).getValue());
|
||||
}
|
||||
});
|
||||
if (valueMap == null) {
|
||||
if (valueMap == null || CollectionUtils.isEmpty(values)) {
|
||||
return;
|
||||
}
|
||||
String value = valueMap.get(JsonUtil.toString(values));
|
||||
if (StringUtils.isBlank(value)) {
|
||||
return;
|
||||
}
|
||||
List<String> valueList = JsonUtil.toList(value, String.class);
|
||||
List<Expression> newExpressions = new ArrayList<>();
|
||||
valueList.stream().forEach(o -> {
|
||||
StringValue stringValue = new StringValue(o);
|
||||
values.stream().forEach(o -> {
|
||||
String replaceValue = valueMap.getOrDefault(o, o);
|
||||
StringValue stringValue = new StringValue(replaceValue);
|
||||
newExpressions.add(stringValue);
|
||||
});
|
||||
rightItemsList.setExpressions(newExpressions);
|
||||
@@ -1,9 +1,5 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -20,6 +16,11 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
@Slf4j
|
||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
@@ -76,37 +77,39 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
public List<Expression> parserFilter(ComparisonOperator comparisonOperator, String condExpr) {
|
||||
List<Expression> result = new ArrayList<>();
|
||||
String toString = comparisonOperator.toString();
|
||||
String comparisonOperatorStr = comparisonOperator.toString();
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return result;
|
||||
}
|
||||
Function leftExpressionFunction = (Function) leftExpression;
|
||||
if (leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
|
||||
|
||||
Function leftFunction = (Function) leftExpression;
|
||||
if (leftFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
|
||||
return result;
|
||||
}
|
||||
|
||||
//List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
||||
ExpressionList<?> leftExpressions = leftExpressionFunction.getParameters();
|
||||
if (CollectionUtils.isEmpty(leftExpressions)) {
|
||||
ExpressionList<?> leftFunctionParams = leftFunction.getParameters();
|
||||
if (CollectionUtils.isEmpty(leftFunctionParams)) {
|
||||
return result;
|
||||
}
|
||||
Column field = (Column) leftExpressions.get(0);
|
||||
|
||||
Column field = (Column) leftFunctionParams.get(0);
|
||||
String columnName = field.getColumnName();
|
||||
if (!fieldNames.contains(columnName)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
||||
comparisonOperator.setASTNode(expression.getASTNode());
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(toString));
|
||||
ComparisonOperator parsedExpression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
||||
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
|
||||
return result;
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
@@ -12,9 +9,11 @@ import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
public static final String PREFIX = "%";
|
||||
public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
private Map<String, Set<String>> fieldValueToFieldNames;
|
||||
|
||||
public FiledNameReplaceVisitor(Map<String, Set<String>> fieldValueToFieldNames) {
|
||||
@@ -34,40 +33,20 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
private void replaceFieldNameByFieldValue(BinaryExpression expr) {
|
||||
Expression leftExpression = expr.getLeftExpression();
|
||||
Expression rightExpression = expr.getRightExpression();
|
||||
if (!(rightExpression instanceof StringValue)) {
|
||||
|
||||
if (!(rightExpression instanceof StringValue) || !(leftExpression instanceof Column)
|
||||
|| CollectionUtils.isEmpty(fieldValueToFieldNames)
|
||||
|| Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(fieldValueToFieldNames)) {
|
||||
return;
|
||||
}
|
||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
Column leftColumnName = (Column) leftExpression;
|
||||
|
||||
Column leftColumn = (Column) leftExpression;
|
||||
StringValue rightStringValue = (StringValue) rightExpression;
|
||||
|
||||
if (expr instanceof LikeExpression) {
|
||||
String value = getValue(rightStringValue.getValue());
|
||||
rightStringValue.setValue(value);
|
||||
}
|
||||
|
||||
Set<String> fieldNames = fieldValueToFieldNames.get(rightStringValue.getValue());
|
||||
if (!CollectionUtils.isEmpty(fieldNames) && !fieldNames.contains(leftColumnName.getColumnName())) {
|
||||
leftColumnName.setColumnName(fieldNames.stream().findFirst().get());
|
||||
if (!CollectionUtils.isEmpty(fieldNames) && !fieldNames.contains(leftColumn.getColumnName())) {
|
||||
leftColumn.setColumnName(fieldNames.stream().findFirst().get());
|
||||
}
|
||||
}
|
||||
|
||||
private String getValue(String value) {
|
||||
if (value.startsWith(PREFIX)) {
|
||||
value = value.substring(1);
|
||||
}
|
||||
if (value.endsWith(PREFIX)) {
|
||||
value = value.substring(0, value.length() - 1);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class FilterRemoveVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private List<String> filedNames;
|
||||
|
||||
public FilterRemoveVisitor(List<String> filedNames) {
|
||||
this.filedNames = filedNames;
|
||||
}
|
||||
|
||||
private boolean isRemove(Expression leftExpression) {
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return false;
|
||||
}
|
||||
Column leftColumnName = (Column) leftExpression;
|
||||
String columnName = leftColumnName.getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return false;
|
||||
}
|
||||
if (!filedNames.contains(columnName)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(EqualsTo expr) {
|
||||
if (!isRemove(expr.getLeftExpression())) {
|
||||
return;
|
||||
}
|
||||
expr.setRightExpression(new LongValue(1L));
|
||||
expr.setLeftExpression(new LongValue(1L));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThan expr) {
|
||||
if (!isRemove(expr.getLeftExpression())) {
|
||||
return;
|
||||
}
|
||||
expr.setRightExpression(new LongValue(1L));
|
||||
expr.setLeftExpression(new LongValue(0L));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThanEquals expr) {
|
||||
if (!isRemove(expr.getLeftExpression())) {
|
||||
return;
|
||||
}
|
||||
expr.setRightExpression(new LongValue(1L));
|
||||
expr.setLeftExpression(new LongValue(1L));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThan expr) {
|
||||
if (!isRemove(expr.getLeftExpression())) {
|
||||
return;
|
||||
}
|
||||
expr.setRightExpression(new LongValue(0L));
|
||||
expr.setLeftExpression(new LongValue(1L));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThanEquals expr) {
|
||||
if (!isRemove(expr.getLeftExpression())) {
|
||||
return;
|
||||
}
|
||||
expr.setRightExpression(new LongValue(1L));
|
||||
expr.setLeftExpression(new LongValue(1L));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(InExpression expr) {
|
||||
if (!isRemove(expr.getLeftExpression())) {
|
||||
return;
|
||||
}
|
||||
expr.setNot(false);
|
||||
expr.setRightExpression(new LongValue(1L));
|
||||
expr.setLeftExpression(new LongValue(1L));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,24 +1,21 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.UnaryOperator;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.UnaryOperator;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionNameReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private Map<String, String> functionMap;
|
||||
private Map<String, UnaryOperator> functionCallMap;
|
||||
|
||||
public FunctionNameReplaceVisitor(Map<String, String> functionMap) {
|
||||
this.functionMap = functionMap;
|
||||
}
|
||||
|
||||
public FunctionNameReplaceVisitor(Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
|
||||
this.functionMap = functionMap;
|
||||
this.functionCallMap = functionCallMap;
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void visit(MinorThan expr) {
|
||||
List<Expression> expressions = reparseDate(expr, ">");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(EqualsTo expr) {
|
||||
List<Expression> expressions = reparseDate(expr, ">=");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThanEquals expr) {
|
||||
List<Expression> expressions = reparseDate(expr, ">=");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThan expr) {
|
||||
List<Expression> expressions = reparseDate(expr, "<");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThanEquals expr) {
|
||||
List<Expression> expressions = reparseDate(expr, "<=");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
public List<Expression> getWaitingForAdds() {
|
||||
return waitingForAdds;
|
||||
}
|
||||
|
||||
public List<Expression> reparseDate(ComparisonOperator comparisonOperator, String startDateOperator) {
|
||||
List<Expression> result = new ArrayList<>();
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return result;
|
||||
}
|
||||
Function leftExpressionFunction = (Function) leftExpression;
|
||||
if (!leftExpressionFunction.toString().contains(JsqlConstants.DATE_FUNCTION)) {
|
||||
return result;
|
||||
}
|
||||
//List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
||||
ExpressionList<?> leftExpressions = leftExpressionFunction.getParameters();
|
||||
if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) {
|
||||
return result;
|
||||
}
|
||||
Column field = (Column) leftExpressions.get(1);
|
||||
String columnName = field.getColumnName();
|
||||
try {
|
||||
String startDateValue = DateFunctionHelper.getStartDateStr(comparisonOperator, leftExpressions);
|
||||
String endDateValue = DateFunctionHelper.getEndDateValue(leftExpressions);
|
||||
String endDateOperator = comparisonOperator.getStringExpression();
|
||||
String condExpr =
|
||||
columnName + StringUtil.getSpaceWrap(DateFunctionHelper.getEndDateOperator(comparisonOperator))
|
||||
+ StringUtil.getCommaWrap(endDateValue);
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
|
||||
String startDataCondExpr =
|
||||
columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
|
||||
if (JsqlConstants.EQUAL.equalsIgnoreCase(endDateOperator)) {
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||
expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(JsqlConstants.EQUAL_CONSTANT);
|
||||
}
|
||||
if (startDateOperator.equals("<=") || startDateOperator.equals("<")) {
|
||||
comparisonOperator.setLeftExpression(new Column("1"));
|
||||
comparisonOperator.setRightExpression(new LongValue(1));
|
||||
comparisonOperator.setASTNode(null);
|
||||
} else {
|
||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
||||
comparisonOperator.setASTNode(expression.getASTNode());
|
||||
}
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(startDataCondExpr));
|
||||
return result;
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,9 +1,5 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.UnaryOperator;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
@@ -12,16 +8,17 @@ import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.UnaryOperator;
|
||||
|
||||
@Slf4j
|
||||
public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
|
||||
|
||||
private Map<String, String> functionMap;
|
||||
private Map<String, UnaryOperator> functionCallMap;
|
||||
|
||||
public GroupByFunctionReplaceVisitor(Map<String, String> functionMap) {
|
||||
this.functionMap = functionMap;
|
||||
}
|
||||
|
||||
public GroupByFunctionReplaceVisitor(Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) {
|
||||
this.functionMap = functionMap;
|
||||
this.functionCallMap = functionCallMap;
|
||||
@@ -31,22 +28,22 @@ public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
|
||||
groupByElement.getGroupByExpressionList();
|
||||
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
|
||||
List<Expression> groupByExpressions = groupByExpressionList.getExpressions();
|
||||
|
||||
for (int i = 0; i < groupByExpressions.size(); i++) {
|
||||
Expression expression = groupByExpressions.get(i);
|
||||
if (expression instanceof Function) {
|
||||
Function function = (Function) expression;
|
||||
String functionName = function.getName().toLowerCase();
|
||||
String replaceName = functionMap.get(functionName);
|
||||
if (StringUtils.isNotBlank(replaceName)) {
|
||||
function.setName(replaceName);
|
||||
if (Objects.nonNull(functionCallMap) && functionCallMap.containsKey(functionName)) {
|
||||
Object ret = functionCallMap.get(functionName).apply(function.getParameters());
|
||||
if (Objects.nonNull(ret) && ret instanceof ExpressionList) {
|
||||
ExpressionList expressionList = (ExpressionList) ret;
|
||||
function.setParameters(expressionList);
|
||||
}
|
||||
}
|
||||
for (Expression expression : groupByExpressions) {
|
||||
if (!(expression instanceof Function)) {
|
||||
continue;
|
||||
}
|
||||
Function function = (Function) expression;
|
||||
String functionName = function.getName().toLowerCase();
|
||||
String replaceName = functionMap.get(functionName);
|
||||
if (StringUtils.isBlank(replaceName)) {
|
||||
continue;
|
||||
}
|
||||
function.setName(replaceName);
|
||||
if (Objects.nonNull(functionCallMap) && functionCallMap.containsKey(functionName)) {
|
||||
Object ret = functionCallMap.get(functionName).apply(function.getParameters());
|
||||
if (Objects.nonNull(ret) && ret instanceof ExpressionList) {
|
||||
ExpressionList expressionList = (ExpressionList) ret;
|
||||
function.setParameters(expressionList);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -14,6 +11,10 @@ import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class GroupByReplaceVisitor implements GroupByVisitor {
|
||||
|
||||
@@ -27,38 +28,51 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
|
||||
}
|
||||
|
||||
public void visit(GroupByElement groupByElement) {
|
||||
groupByElement.getGroupByExpressionList();
|
||||
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
|
||||
List<Expression> groupByExpressions = groupByExpressionList.getExpressions();
|
||||
|
||||
for (int i = 0; i < groupByExpressions.size(); i++) {
|
||||
Expression expression = groupByExpressions.get(i);
|
||||
String columnName = expression.toString();
|
||||
if (expression instanceof Function && Objects.nonNull(
|
||||
((Function) expression).getParameters().getExpressions().get(0))) {
|
||||
columnName = ((Function) expression).getParameters().getExpressions().get(0).toString();
|
||||
}
|
||||
String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap,
|
||||
exactReplace);
|
||||
String columnName = getColumnName(expression);
|
||||
|
||||
String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap, exactReplace);
|
||||
if (StringUtils.isNotEmpty(replaceColumn)) {
|
||||
if (expression instanceof Column) {
|
||||
groupByExpressions.set(i, new Column(replaceColumn));
|
||||
}
|
||||
if (expression instanceof Function) {
|
||||
try {
|
||||
Expression element = CCJSqlParserUtil.parseExpression(replaceColumn);
|
||||
ExpressionList<Expression> expressionList = new ExpressionList<Expression>();
|
||||
expressionList.add(element);
|
||||
if (((Function) expression).getParameters().size() > 1) {
|
||||
((Function) expression).getParameters().stream().skip(1).forEach(e -> {
|
||||
expressionList.add((Function) e);
|
||||
});
|
||||
}
|
||||
((Function) expression).setParameters(expressionList);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("e", e);
|
||||
}
|
||||
replaceExpression(groupByExpressions, i, expression, replaceColumn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String getColumnName(Expression expression) {
|
||||
if (expression instanceof Function) {
|
||||
Function function = (Function) expression;
|
||||
if (Objects.nonNull(function.getParameters().getExpressions().get(0))) {
|
||||
return function.getParameters().getExpressions().get(0).toString();
|
||||
}
|
||||
}
|
||||
return expression.toString();
|
||||
}
|
||||
|
||||
private void replaceExpression(List<Expression> groupByExpressions,
|
||||
int index,
|
||||
Expression expression,
|
||||
String replaceColumn) {
|
||||
if (expression instanceof Column) {
|
||||
groupByExpressions.set(index, new Column(replaceColumn));
|
||||
} else if (expression instanceof Function) {
|
||||
try {
|
||||
Expression newExpression = CCJSqlParserUtil.parseExpression(replaceColumn);
|
||||
ExpressionList<Expression> newExpressionList = new ExpressionList<>();
|
||||
newExpressionList.add(newExpression);
|
||||
|
||||
Function function = (Function) expression;
|
||||
if (function.getParameters().size() > 1) {
|
||||
function.getParameters().stream().skip(1).forEach(
|
||||
e -> newExpressionList.add((Function) e)
|
||||
);
|
||||
}
|
||||
function.setParameters(newExpressionList);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("Error parsing expression: {}", replaceColumn, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,6 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
}
|
||||
}
|
||||
}
|
||||
//selectExpressionItem.getExpression().accept(this);
|
||||
}
|
||||
|
||||
public static Expression replace(Expression expression, Map<String, String> fieldExprMap) {
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.BinaryExpression;
|
||||
@@ -33,6 +28,12 @@ import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Sql Parser remove Helper
|
||||
*/
|
||||
@@ -228,7 +229,6 @@ public class SqlRemoveHelper {
|
||||
if (selectStatement == null) {
|
||||
return sql;
|
||||
}
|
||||
//SelectBody selectBody = selectStatement.getSelectBody();
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
|
||||
@@ -2,15 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.UnaryOperator;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Alias;
|
||||
@@ -30,6 +21,7 @@ import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.select.FromItem;
|
||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
import net.sf.jsqlparser.statement.select.Join;
|
||||
import net.sf.jsqlparser.statement.select.OrderByElement;
|
||||
@@ -40,11 +32,18 @@ import net.sf.jsqlparser.statement.select.Select;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||
import net.sf.jsqlparser.statement.select.SetOperationList;
|
||||
import net.sf.jsqlparser.statement.select.FromItem;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.UnaryOperator;
|
||||
|
||||
/**
|
||||
* Sql Parser replace Helper
|
||||
*/
|
||||
@@ -127,12 +126,10 @@ public class SqlReplaceHelper {
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
//List<PlainSelect> plainSelectList = new ArrayList<>();
|
||||
//plainSelectList.add((PlainSelect) selectStatement);
|
||||
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelect(selectStatement);
|
||||
for (PlainSelect plainSelect : plainSelects) {
|
||||
Expression where = plainSelect.getWhere();
|
||||
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap);
|
||||
FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(exactReplace, filedNameToValueMap);
|
||||
if (Objects.nonNull(where)) {
|
||||
where.accept(visitor);
|
||||
}
|
||||
@@ -187,18 +184,14 @@ public class SqlReplaceHelper {
|
||||
public static String replaceFields(String sql, Map<String, String> fieldNameMap, boolean exactReplace) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
List<PlainSelect> plainSelectList = SqlSelectHelper.getWithItem(selectStatement);
|
||||
//plainSelectList.add(selectStatement.getPlainSelect());
|
||||
if (selectStatement instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
||||
plainSelectList.add(plainSelect);
|
||||
getFromSelect(plainSelect.getFromItem(), plainSelectList);
|
||||
//plainSelectList.add((PlainSelect) selectStatement);
|
||||
} else if (selectStatement instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) selectStatement;
|
||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||
setOperationList.getSelects().forEach(subSelectBody -> {
|
||||
//PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
//plainSelectList.add(subPlainSelect);
|
||||
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
plainSelectList.add(subPlainSelect);
|
||||
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
|
||||
@@ -546,7 +539,7 @@ public class SqlReplaceHelper {
|
||||
}
|
||||
PlainSelect plainSelect = (PlainSelect) selectStatement;
|
||||
Expression having = plainSelect.getHaving();
|
||||
FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(false, filedNameToValueMap);
|
||||
FieldValueReplaceVisitor visitor = new FieldValueReplaceVisitor(false, filedNameToValueMap);
|
||||
if (Objects.nonNull(having)) {
|
||||
having.accept(visitor);
|
||||
}
|
||||
|
||||
@@ -1,14 +1,6 @@
|
||||
package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Alias;
|
||||
@@ -50,6 +42,15 @@ import net.sf.jsqlparser.statement.select.WithItem;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Sql Parser Select Helper
|
||||
*/
|
||||
@@ -97,6 +98,22 @@ public class SqlSelectHelper {
|
||||
});
|
||||
}
|
||||
|
||||
public static List<String> gePureSelectFields(String sql) {
|
||||
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
||||
Set<String> result = new HashSet<>();
|
||||
plainSelectList.stream().forEach(plainSelect -> {
|
||||
List<SelectItem<?>> selectItems = plainSelect.getSelectItems();
|
||||
for (SelectItem selectItem : selectItems) {
|
||||
if (!(selectItem.getExpression() instanceof Column)) {
|
||||
continue;
|
||||
}
|
||||
Column column = (Column) selectItem.getExpression();
|
||||
result.add(column.getColumnName());
|
||||
}
|
||||
});
|
||||
return new ArrayList<>(result);
|
||||
}
|
||||
|
||||
public static List<String> getSelectFields(String sql) {
|
||||
List<PlainSelect> plainSelectList = getPlainSelect(sql);
|
||||
if (CollectionUtils.isEmpty(plainSelectList)) {
|
||||
@@ -244,7 +261,7 @@ public class SqlSelectHelper {
|
||||
return plainSelects;
|
||||
}
|
||||
|
||||
public static List<String> getAllFields(String sql) {
|
||||
public static List<String> getAllSelectFields(String sql) {
|
||||
List<PlainSelect> plainSelects = getPlainSelects(getPlainSelect(sql));
|
||||
Set<String> results = new HashSet<>();
|
||||
for (PlainSelect plainSelect : plainSelects) {
|
||||
@@ -632,22 +649,6 @@ public class SqlSelectHelper {
|
||||
return withNameList;
|
||||
}
|
||||
|
||||
public static Map<String, WithItem> getWith(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
if (selectStatement == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
Map<String, WithItem> withMap = new HashMap<>();
|
||||
List<WithItem> withItemList = selectStatement.getWithItemsList();
|
||||
if (!CollectionUtils.isEmpty(withItemList)) {
|
||||
for (int i = 0; i < withItemList.size(); i++) {
|
||||
WithItem withItem = withItemList.get(i);
|
||||
withMap.put(withItem.getAlias().getName(), withItem);
|
||||
}
|
||||
}
|
||||
return withMap;
|
||||
}
|
||||
|
||||
public static Table getTable(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
if (selectStatement == null) {
|
||||
@@ -776,24 +777,25 @@ public class SqlSelectHelper {
|
||||
|
||||
private static void getFieldsWithSubQuery(PlainSelect plainSelect, Map<String, Set<String>> fields) {
|
||||
if (plainSelect.getFromItem() instanceof Table) {
|
||||
boolean isWith = false;
|
||||
List<String> withAlias = new ArrayList<>();
|
||||
if (!CollectionUtils.isEmpty(plainSelect.getWithItemsList())) {
|
||||
for (WithItem withItem : plainSelect.getWithItemsList()) {
|
||||
if (Objects.nonNull(withItem.getSelect())) {
|
||||
getFieldsWithSubQuery(withItem.getSelect().getPlainSelect(), fields);
|
||||
isWith = true;
|
||||
withAlias.add(withItem.getAlias().getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!isWith) {
|
||||
Table table = (Table) plainSelect.getFromItem();
|
||||
Table table = (Table) plainSelect.getFromItem();
|
||||
String tableName = table.getFullyQualifiedName();
|
||||
if (!withAlias.contains(tableName)) {
|
||||
if (!fields.containsKey(table.getFullyQualifiedName())) {
|
||||
fields.put(table.getFullyQualifiedName(), new HashSet<>());
|
||||
fields.put(tableName, new HashSet<>());
|
||||
}
|
||||
List<String> sqlFields = getFieldsByPlainSelect(plainSelect).stream().map(f -> f.replaceAll("`", ""))
|
||||
.collect(
|
||||
Collectors.toList());
|
||||
fields.get(table.getFullyQualifiedName()).addAll(sqlFields);
|
||||
fields.get(tableName).addAll(sqlFields);
|
||||
}
|
||||
}
|
||||
if (plainSelect.getFromItem() instanceof ParenthesedSelect) {
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.common.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -29,8 +28,8 @@ public class SqlValidHelper {
|
||||
}
|
||||
|
||||
//2. all fields
|
||||
List<String> thisAllFields = SqlSelectHelper.getAllFields(thisSql);
|
||||
List<String> otherAllFields = SqlSelectHelper.getAllFields(otherSql);
|
||||
List<String> thisAllFields = SqlSelectHelper.getAllSelectFields(thisSql);
|
||||
List<String> otherAllFields = SqlSelectHelper.getAllSelectFields(otherSql);
|
||||
|
||||
if (!CollectionUtils.isEqualCollection(thisAllFields, otherAllFields)) {
|
||||
return false;
|
||||
@@ -69,7 +68,7 @@ public class SqlValidHelper {
|
||||
try {
|
||||
CCJSqlParserUtil.parse(sql);
|
||||
return true;
|
||||
} catch (JSQLParserException e) {
|
||||
} catch (Exception e) {
|
||||
log.error("isValidSQL parse:{}", e);
|
||||
return false;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user