mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
131 Commits
v0.9.10
...
62b9db6791
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62b9db6791 | ||
|
|
6d907b6adf | ||
|
|
da172a030e | ||
|
|
47c2595fb8 | ||
|
|
9bddd4457e | ||
|
|
55ac3d1aa5 | ||
|
|
0427917624 | ||
|
|
d8fe2ed2b3 | ||
|
|
11d1264d38 | ||
|
|
32675387d7 | ||
|
|
e408204690 | ||
|
|
269f146c11 | ||
|
|
6f497b142e | ||
|
|
79a44b27ee | ||
|
|
76cc5ee111 | ||
|
|
320fcf04bd | ||
|
|
75fc83010c | ||
|
|
37673c82da | ||
|
|
3ae0d645a7 | ||
|
|
256a6bcb3f | ||
|
|
1faf84e372 | ||
|
|
7e6639df83 | ||
|
|
075ae4c0af | ||
|
|
08133ccbfb | ||
|
|
164d2a9e23 | ||
|
|
f899d23b63 | ||
|
|
944beddafc | ||
|
|
019d737f07 | ||
|
|
0721df2e66 | ||
|
|
303392f492 | ||
|
|
e5a41765b4 | ||
|
|
87355533b4 | ||
|
|
06fb6ba744 | ||
|
|
9ffdba956e | ||
|
|
df70a3cf15 | ||
|
|
2552e2ae4b | ||
|
|
4bfa10ba7c | ||
|
|
958aca945d | ||
|
|
fae9118c28 | ||
|
|
c24ba59bb5 | ||
|
|
90c4f66770 | ||
|
|
b9dd6bb7c5 | ||
|
|
dff64b62f4 | ||
|
|
8eeed87bac | ||
|
|
e171bdd97f | ||
|
|
0709575cd9 | ||
|
|
be0447ae15 | ||
|
|
1b8cd7f0d3 | ||
|
|
2fd82cc259 | ||
|
|
00814a3807 | ||
|
|
08705c9d3b | ||
|
|
1c9cf788cb | ||
|
|
1ab5d9c7e6 | ||
|
|
2b13866c0b | ||
|
|
e812884802 | ||
|
|
e2ae7e21ad | ||
|
|
3fc1ec42be | ||
|
|
c4992501bd | ||
|
|
acffc03c79 | ||
|
|
763def2de0 | ||
|
|
d0a67af684 | ||
|
|
be8b56bdde | ||
|
|
9f2c0c7699 | ||
|
|
c1fa9d7442 | ||
|
|
0d5da763b3 | ||
|
|
d1b4863a27 | ||
|
|
dce9a8a58c | ||
|
|
fbf048cb00 | ||
|
|
48a8f69cca | ||
|
|
ecdf65da3e | ||
|
|
5585b9e222 | ||
|
|
97710a90c4 | ||
|
|
0ab7643299 | ||
|
|
d2aa73b85e | ||
|
|
8828964e53 | ||
|
|
b188da8595 | ||
|
|
2e7ba468c9 | ||
|
|
d26c9180ed | ||
|
|
ca96aa725d | ||
|
|
614917ba76 | ||
|
|
1fed8ca4d9 | ||
|
|
232a202275 | ||
|
|
791c493a6a | ||
|
|
f9d4ce2128 | ||
|
|
8abfc923a0 | ||
|
|
e6598a79bb | ||
|
|
d2a43a99c8 | ||
|
|
db8f340e2d | ||
|
|
2e81b190a4 | ||
|
|
81cd60d2da | ||
|
|
3ffc8c3d9e | ||
|
|
18db24c011 | ||
|
|
cd698ac367 | ||
|
|
58b640b087 | ||
|
|
1f28aaeaed | ||
|
|
35b835172b | ||
|
|
1c85bcecc5 | ||
|
|
c3483ae340 | ||
|
|
a5051c7225 | ||
|
|
12f6cfa42d | ||
|
|
4c94f2b816 | ||
|
|
c81aa5859d | ||
|
|
21e213fb19 | ||
|
|
f67bf3eeac | ||
|
|
9d13038599 | ||
|
|
0c8c2d4804 | ||
|
|
f05a4b523c | ||
|
|
b7369abcca | ||
|
|
b40cb13740 | ||
|
|
6f8cf9853b | ||
|
|
75906037ac | ||
|
|
b58e041e8d | ||
|
|
93d585c0d5 | ||
|
|
0dbf56d357 | ||
|
|
a3293e6788 | ||
|
|
a99f5985f5 | ||
|
|
91243005bc | ||
|
|
a76b5a4300 | ||
|
|
c1f9df963c | ||
|
|
954aa4eea5 | ||
|
|
33bd0de604 | ||
|
|
881d891d70 | ||
|
|
d9db455dab | ||
|
|
e0dc3fbf1a | ||
|
|
efddf4cacf | ||
|
|
732222ab98 | ||
|
|
5b994c4f8f | ||
|
|
5d2ebdf680 | ||
|
|
f1bc18ef65 | ||
|
|
8f361f9932 | ||
|
|
f532088e38 |
11
.github/workflows/centos-ci.yml
vendored
11
.github/workflows/centos-ci.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
|||||||
build:
|
build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
container:
|
container:
|
||||||
image: quay.io/centos/centos:stream8 # 使用 CentOS Stream 8 容器
|
image: almalinux:9 # maven >=3.6.3
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
@@ -28,9 +28,10 @@ jobs:
|
|||||||
|
|
||||||
- name: Reset DNF repositories
|
- name: Reset DNF repositories
|
||||||
run: |
|
run: |
|
||||||
cd /etc/yum.repos.d/
|
sed -e 's|^mirrorlist=|#mirrorlist=|g' \
|
||||||
sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-*
|
-e 's|^# baseurl=https://repo.almalinux.org|baseurl=https://mirrors.aliyun.com|g' \
|
||||||
sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
|
/etc/yum.repos.d/almalinux*.repo
|
||||||
|
|
||||||
|
|
||||||
- name: Update DNF package index
|
- name: Update DNF package index
|
||||||
run: dnf makecache
|
run: dnf makecache
|
||||||
@@ -47,7 +48,7 @@ jobs:
|
|||||||
mvn -version
|
mvn -version
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.m2
|
path: ~/.m2
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
|
|||||||
14
.github/workflows/mac-ci.yml
vendored
14
.github/workflows/mac-ci.yml
vendored
@@ -17,21 +17,27 @@ jobs:
|
|||||||
java-version: [21] # Define the JDK versions to test
|
java-version: [21] # Define the JDK versions to test
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Set up JDK ${{ matrix.java-version }}
|
- name: Set up JDK ${{ matrix.java-version }}
|
||||||
uses: actions/setup-java@v2
|
uses: actions/setup-java@v3
|
||||||
with:
|
with:
|
||||||
java-version: ${{ matrix.java-version }}
|
java-version: ${{ matrix.java-version }}
|
||||||
distribution: 'adopt'
|
distribution: 'temurin'
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v3
|
||||||
with:
|
with:
|
||||||
path: ~/Library/Caches/Maven # macOS Maven cache path
|
path: ~/Library/Caches/Maven # macOS Maven cache path
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
restore-keys: ${{ runner.os }}-m2
|
restore-keys: ${{ runner.os }}-m2
|
||||||
|
|
||||||
|
- name: Install system dependencies
|
||||||
|
run: |
|
||||||
|
brew update
|
||||||
|
brew install cmake
|
||||||
|
brew install gcc
|
||||||
|
|
||||||
- name: Build with Maven
|
- name: Build with Maven
|
||||||
run: mvn -B package --file pom.xml
|
run: mvn -B package --file pom.xml
|
||||||
|
|
||||||
|
|||||||
2
.github/workflows/ubuntu-ci.yml
vendored
2
.github/workflows/ubuntu-ci.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
distribution: 'adopt'
|
distribution: 'adopt'
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/.m2
|
path: ~/.m2
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
|
|||||||
2
.github/workflows/windows-ci.yml
vendored
2
.github/workflows/windows-ci.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
|||||||
distribution: 'adopt' # You might need to change this if 'adopt' doesn't support JDK 21
|
distribution: 'adopt' # You might need to change this if 'adopt' doesn't support JDK 21
|
||||||
|
|
||||||
- name: Cache Maven packages
|
- name: Cache Maven packages
|
||||||
uses: actions/cache@v2
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~\.m2 # Windows uses a backslash for paths
|
path: ~\.m2 # Windows uses a backslash for paths
|
||||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,3 +21,4 @@ __pycache__/
|
|||||||
/dict
|
/dict
|
||||||
assembly/build/*-SNAPSHOT
|
assembly/build/*-SNAPSHOT
|
||||||
**/node_modules/
|
**/node_modules/
|
||||||
|
benchmark/res/
|
||||||
2
LICENSE
2
LICENSE
@@ -14,7 +14,7 @@ code and logo.
|
|||||||
b. a commercial license must be obtained from the author if you want to develop and distribute a derivative work based
|
b. a commercial license must be obtained from the author if you want to develop and distribute a derivative work based
|
||||||
on SuperSonic.
|
on SuperSonic.
|
||||||
|
|
||||||
Please contact zhangjun2915@163.com by email to inquire about licensing matters.
|
Please contact supersonicbi@qq.com by email to inquire about licensing matters.
|
||||||
|
|
||||||
|
|
||||||
2. As a contributor, you should agree that:
|
2. As a contributor, you should agree that:
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ SuperSonic comes with sample semantic models as well as chat conversations that
|
|||||||
|
|
||||||
## Build and Development
|
## Build and Development
|
||||||
|
|
||||||
Please refer to project [Docs](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%BB%9F%E9%83%A8%E7%BD%B2/%E7%BC%96%E8%AF%91%E6%9E%84%E5%BB%BA/).
|
Please refer to project [Docs](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%BB%9F%E9%83%A8%E7%BD%B2/%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91%E9%83%A8%E7%BD%B2/).
|
||||||
|
|
||||||
## WeChat Contact
|
## WeChat Contact
|
||||||
|
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ SuperSonic自带样例的语义模型和问答对话,只需以下三步即可
|
|||||||
|
|
||||||
## 如何构建和部署
|
## 如何构建和部署
|
||||||
|
|
||||||
请参考项目[文档](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%BB%9F%E9%83%A8%E7%BD%B2/%E7%BC%96%E8%AF%91%E6%9E%84%E5%BB%BA/)。
|
请参考项目[文档](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%BB%9F%E9%83%A8%E7%BD%B2/%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91%E9%83%A8%E7%BD%B2/)。
|
||||||
|
|
||||||
## 微信联系方式
|
## 微信联系方式
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ SuperSonicには、サンプルのセマンティックモデルとチャット
|
|||||||
|
|
||||||
## ビルドと開発
|
## ビルドと開発
|
||||||
|
|
||||||
プロジェクト[ドキュメント](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%BB%9F%E9%83%A8%E7%BD%B2/%E7%BC%96%E8%AF%91%E6%9E%84%E5%BB%BA/)を参照してください。
|
プロジェクト[ドキュメント](https://supersonicbi.github.io/docs/%E7%B3%BB%E7%BB%9F%E9%83%A8%E7%BD%B2/%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91%E9%83%A8%E7%BD%B2/)を参照してください。
|
||||||
|
|
||||||
## WeChat連絡先
|
## WeChat連絡先
|
||||||
|
|
||||||
|
|||||||
@@ -43,10 +43,26 @@ if "%service%"=="webapp" (
|
|||||||
call mvn -f %projectDir% clean package -DskipTests -Dspotless.skip=true
|
call mvn -f %projectDir% clean package -DskipTests -Dspotless.skip=true
|
||||||
IF ERRORLEVEL 1 (
|
IF ERRORLEVEL 1 (
|
||||||
ECHO Failed to build backend Java modules.
|
ECHO Failed to build backend Java modules.
|
||||||
|
ECHO Please check Maven and Java versions are compatible.
|
||||||
|
ECHO Current Java: %JAVA_HOME%
|
||||||
|
ECHO Current Maven: %MAVEN_HOME%
|
||||||
EXIT /B 1
|
EXIT /B 1
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REM extract and copy files to deployment directory
|
||||||
|
cd %projectDir%\launchers\%model_name%\target
|
||||||
|
if exist "launchers-%model_name%-%MVN_VERSION%-bin.tar.gz" (
|
||||||
|
echo "Extracting launchers-%model_name%-%MVN_VERSION%-bin.tar.gz..."
|
||||||
|
tar -xf "launchers-%model_name%-%MVN_VERSION%-bin.tar.gz"
|
||||||
|
if exist "launchers-%model_name%-%MVN_VERSION%" (
|
||||||
|
echo "Copying files to deployment directory..."
|
||||||
|
xcopy /E /Y "launchers-%model_name%-%MVN_VERSION%\*" "%buildDir%\supersonic-%model_name%-%MVN_VERSION%\"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
copy /y %projectDir%\launchers\%model_name%\target\*.tar.gz %buildDir%\
|
copy /y %projectDir%\launchers\%model_name%\target\*.tar.gz %buildDir%\
|
||||||
echo "finished building supersonic-%model_name% service"
|
echo "finished building supersonic-%model_name% service"
|
||||||
|
cd %baseDir%
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
|
||||||
@@ -72,22 +88,55 @@ if "%service%"=="webapp" (
|
|||||||
cd %buildDir%
|
cd %buildDir%
|
||||||
if exist %release_dir% rmdir /s /q %release_dir%
|
if exist %release_dir% rmdir /s /q %release_dir%
|
||||||
if exist %release_dir%.zip del %release_dir%.zip
|
if exist %release_dir%.zip del %release_dir%.zip
|
||||||
|
|
||||||
|
rem check if release directory already exists from buildJavaService
|
||||||
|
if exist %release_dir% (
|
||||||
|
echo "Release directory already prepared by buildJavaService"
|
||||||
|
) else (
|
||||||
mkdir %release_dir%
|
mkdir %release_dir%
|
||||||
rem package webapp
|
|
||||||
tar xvf supersonic-webapp.tar.gz
|
|
||||||
move /y supersonic-webapp webapp
|
|
||||||
echo {"env": ""} > webapp\supersonic.config.json
|
|
||||||
move /y webapp %release_dir%
|
|
||||||
rem package java service
|
rem package java service
|
||||||
tar xvf %service_name%-bin.tar.gz
|
tar xvf %service_name%-bin.tar.gz 2>nul
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo "Warning: tar command failed, trying PowerShell extraction..."
|
||||||
|
powershell -Command "Expand-Archive -Path '%service_name%-bin.tar.gz' -DestinationPath '.' -Force"
|
||||||
|
)
|
||||||
for /d %%D in ("%service_name%\*") do (
|
for /d %%D in ("%service_name%\*") do (
|
||||||
move "%%D" "%release_dir%"
|
move "%%D" "%release_dir%"
|
||||||
)
|
)
|
||||||
|
rmdir /s /q %service_name% 2>nul
|
||||||
|
)
|
||||||
|
|
||||||
|
rem package webapp
|
||||||
|
if exist supersonic-webapp.tar.gz (
|
||||||
|
tar xvf supersonic-webapp.tar.gz 2>nul
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo "Warning: tar command failed, trying PowerShell extraction..."
|
||||||
|
powershell -Command "Expand-Archive -Path 'supersonic-webapp.tar.gz' -DestinationPath '.' -Force"
|
||||||
|
)
|
||||||
|
move /y supersonic-webapp webapp
|
||||||
|
echo {"env": ""} > webapp\supersonic.config.json
|
||||||
|
move /y webapp %release_dir%
|
||||||
|
del supersonic-webapp.tar.gz 2>nul
|
||||||
|
)
|
||||||
|
|
||||||
|
rem verify deployment structure
|
||||||
|
if exist "%release_dir%\lib\launchers-%model_name%-%MVN_VERSION%.jar" (
|
||||||
|
echo "Deployment structure verified successfully"
|
||||||
|
) else (
|
||||||
|
echo "Warning: Main jar file not found in deployment structure"
|
||||||
|
echo "Expected: %release_dir%\lib\launchers-%model_name%-%MVN_VERSION%.jar"
|
||||||
|
)
|
||||||
|
|
||||||
rem generate zip file
|
rem generate zip file
|
||||||
powershell Compress-Archive -Path %release_dir% -DestinationPath %release_dir%.zip
|
powershell -Command "Compress-Archive -Path '%release_dir%' -DestinationPath '%release_dir%.zip' -Force"
|
||||||
del %service_name%-bin.tar.gz
|
if errorlevel 1 (
|
||||||
del supersonic-webapp.tar.gz
|
echo "Warning: PowerShell compression failed, release directory still available: %release_dir%"
|
||||||
rmdir /s /q %service_name%
|
) else (
|
||||||
|
echo "Successfully created release package: %release_dir%.zip"
|
||||||
|
)
|
||||||
|
|
||||||
|
del %service_name%-bin.tar.gz 2>nul
|
||||||
echo "finished packaging supersonic release"
|
echo "finished packaging supersonic release"
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ if "%profile%"=="" (
|
|||||||
|
|
||||||
set "model_name=%service%"
|
set "model_name=%service%"
|
||||||
|
|
||||||
cd %baseDir%
|
REM fix path configuration - point to the correct release package directory
|
||||||
|
set "releaseDir=%buildDir%\supersonic-%service%-1.0.0-SNAPSHOT"
|
||||||
|
cd %releaseDir%
|
||||||
|
|
||||||
if "%command%"=="restart" (
|
if "%command%"=="restart" (
|
||||||
call :stop
|
call :stop
|
||||||
@@ -50,20 +52,58 @@ if "%command%"=="restart" (
|
|||||||
|
|
||||||
:runJavaService
|
:runJavaService
|
||||||
echo 'java service starting, see logs in logs/'
|
echo 'java service starting, see logs in logs/'
|
||||||
set "libDir=%baseDir%\lib"
|
echo 'Using release directory: %releaseDir%'
|
||||||
set "confDir=%baseDir%\conf"
|
|
||||||
set "webDir=%baseDir%\webapp"
|
REM use release package directory as base path
|
||||||
set "logDir=%baseDir%\logs"
|
set "libDir=%releaseDir%\lib"
|
||||||
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
set "confDir=%releaseDir%\conf"
|
||||||
set "property=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile%"
|
set "webDir=%releaseDir%\webapp"
|
||||||
set "java-command=%property% -Xms1024m -Xmx1024m -cp %CLASSPATH% %MAIN_CLASS%"
|
set "logDir=%releaseDir%\logs"
|
||||||
|
|
||||||
|
REM fix variable name matching problem
|
||||||
|
set "CLASSPATH=%releaseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||||
|
set "MAIN_CLASS=%main_class%"
|
||||||
|
|
||||||
|
REM add port configuration
|
||||||
|
set "property=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Dserver.port=9080"
|
||||||
|
set "java_command=%property% -Xms1024m -Xmx2048m -cp "%CLASSPATH%" %MAIN_CLASS%"
|
||||||
|
|
||||||
if not exist %logDir% mkdir %logDir%
|
if not exist %logDir% mkdir %logDir%
|
||||||
start /B java %java-command% >nul 2>&1
|
|
||||||
timeout /t 10 >nul
|
REM check if the main jar file exists
|
||||||
|
if not exist "%libDir%\launchers-standalone-1.0.0-SNAPSHOT.jar" (
|
||||||
|
echo "Error: Main jar file not found in %libDir%"
|
||||||
|
echo "Please make sure the application has been built and packaged correctly."
|
||||||
|
goto :EOF
|
||||||
|
)
|
||||||
|
|
||||||
|
echo 'Main Class: %MAIN_CLASS%'
|
||||||
|
echo 'Profile: %profile%'
|
||||||
|
echo 'Starting Java service...'
|
||||||
|
|
||||||
|
REM start service and save logs
|
||||||
|
start /B java %java_command% > "%logDir%\supersonic.log" 2>&1
|
||||||
|
timeout /t 15 >nul
|
||||||
|
|
||||||
|
REM check service status
|
||||||
|
netstat -an | findstr ":9080" >nul
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo "Warning: Port 9080 is not listening"
|
||||||
|
echo "Please check the log file: %logDir%\supersonic.log"
|
||||||
|
if exist "%logDir%\supersonic.log" (
|
||||||
|
echo "Recent log entries:"
|
||||||
|
powershell -Command "Get-Content '%logDir%\supersonic.log' | Select-Object -Last 10"
|
||||||
|
)
|
||||||
|
) else (
|
||||||
|
echo "Service started successfully on port 9080"
|
||||||
|
echo "You can access the application at: http://localhost:9080"
|
||||||
|
)
|
||||||
|
|
||||||
echo 'java service started'
|
echo 'java service started'
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
:stopJavaService
|
:stopJavaService
|
||||||
|
echo 'Stopping Java service...'
|
||||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
||||||
taskkill /PID %%i /F
|
taskkill /PID %%i /F
|
||||||
echo "java service (PID = %%i) is killed."
|
echo "java service (PID = %%i) is killed."
|
||||||
|
|||||||
@@ -60,7 +60,8 @@ function runJavaService {
|
|||||||
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
||||||
fi
|
fi
|
||||||
export PATH=$JAVA_HOME/bin:$PATH
|
export PATH=$JAVA_HOME/bin:$PATH
|
||||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dapp_name=${local_app_name} -Xms1024m -Xmx1024m $main_class"
|
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08
|
||||||
|
-Dapp_name=${local_app_name} -Xms1024m -Xmx2048m -XX:+UseZGC -XX:+ZGenerational $main_class"
|
||||||
|
|
||||||
mkdir -p $javaRunDir/logs
|
mkdir -p $javaRunDir/logs
|
||||||
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
|||||||
5
assembly/bin/supersonic-docker-compose.sh
Normal file
5
assembly/bin/supersonic-docker-compose.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
export SUPERSONIC_VERSION=latest
|
||||||
|
|
||||||
|
docker-compose -f docker-compose.yml -p supersonic up
|
||||||
23
assembly/bin/supersonic-docker-run.sh
Normal file
23
assembly/bin/supersonic-docker-run.sh
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
export SUPERSONIC_VERSION=latest
|
||||||
|
|
||||||
|
#### Set below DB configs to connect to your own database
|
||||||
|
# Supported DB_TYPE: h2, mysql, postgres
|
||||||
|
export S2_DB_TYPE=h2
|
||||||
|
export S2_DB_HOST=
|
||||||
|
export S2_DB_PORT=
|
||||||
|
export S2_DB_USER=
|
||||||
|
export S2_DB_PASSWORD=
|
||||||
|
export S2_DB_DATABASE=
|
||||||
|
|
||||||
|
docker run --rm -it -d \
|
||||||
|
--name supersonic_standalone \
|
||||||
|
-p 9080:9080 \
|
||||||
|
-e S2_DB_TYPE=${S2_DB_TYPE} \
|
||||||
|
-e S2_DB_HOST=${S2_DB_HOST} \
|
||||||
|
-e S2_DB_PORT=${S2_DB_PORT} \
|
||||||
|
-e S2_DB_USER=${S2_DB_USER} \
|
||||||
|
-e S2_DB_PASSWORD=${S2_DB_PASSWORD} \
|
||||||
|
-e S2_DB_DATABASE=${S2_DB_DATABASE} \
|
||||||
|
supersonicbi/supersonic:${SUPERSONIC_VERSION}
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
#### Set below DB configs to connect to your own database
|
#### Set below DB configs to connect to your own database
|
||||||
|
# Comment out below exports to config your DB connection
|
||||||
# Supported DB_TYPE: h2, mysql, postgres
|
# Supported DB_TYPE: h2, mysql, postgres
|
||||||
export S2_DB_TYPE=h2
|
#export S2_DB_TYPE=h2
|
||||||
export S2_DB_HOST=
|
#export S2_DB_HOST=
|
||||||
export S2_DB_PORT=
|
#export S2_DB_PORT=
|
||||||
export S2_DB_USER=
|
#export S2_DB_USER=
|
||||||
export S2_DB_PASSWORD=
|
#export S2_DB_PASSWORD=
|
||||||
export S2_DB_DATABASE=
|
#export S2_DB_DATABASE=
|
||||||
@@ -34,8 +34,8 @@
|
|||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<properties>
|
<properties>
|
||||||
<maven.compiler.source>8</maven.compiler.source>
|
<maven.compiler.source>21</maven.compiler.source>
|
||||||
<maven.compiler.target>8</maven.compiler.target>
|
<maven.compiler.target>21</maven.compiler.target>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
@@ -21,6 +21,8 @@ public interface UserAdaptor {
|
|||||||
|
|
||||||
void register(UserReq userReq);
|
void register(UserReq userReq);
|
||||||
|
|
||||||
|
void deleteUser(long userId);
|
||||||
|
|
||||||
String login(UserReq userReq, HttpServletRequest request);
|
String login(UserReq userReq, HttpServletRequest request);
|
||||||
|
|
||||||
String login(UserReq userReq, String appKey);
|
String login(UserReq userReq, String appKey);
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ public class UserWithPassword extends User {
|
|||||||
|
|
||||||
public UserWithPassword(Long id, String name, String displayName, String email, String password,
|
public UserWithPassword(Long id, String name, String displayName, String email, String password,
|
||||||
Integer isAdmin) {
|
Integer isAdmin) {
|
||||||
super(id, name, displayName, email, isAdmin);
|
super(id, name, displayName, email, isAdmin, null);
|
||||||
this.password = password;
|
this.password = password;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ public interface UserService {
|
|||||||
|
|
||||||
void register(UserReq userCmd);
|
void register(UserReq userCmd);
|
||||||
|
|
||||||
|
void deleteUser(long userId);
|
||||||
|
|
||||||
String login(UserReq userCmd, HttpServletRequest request);
|
String login(UserReq userCmd, HttpServletRequest request);
|
||||||
|
|
||||||
String login(UserReq userCmd, String appKey);
|
String login(UserReq userCmd, String appKey);
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ import jakarta.servlet.http.HttpServletRequest;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
|
import java.sql.Timestamp;
|
||||||
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -90,6 +92,12 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
userRepository.addUser(userDO);
|
userRepository.addUser(userDO);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteUser(long userId) {
|
||||||
|
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||||
|
userRepository.deleteUser(userId);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String login(UserReq userReq, HttpServletRequest request) {
|
public String login(UserReq userReq, HttpServletRequest request) {
|
||||||
TokenService tokenService = ContextUtils.getBean(TokenService.class);
|
TokenService tokenService = ContextUtils.getBean(TokenService.class);
|
||||||
@@ -102,7 +110,9 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
TokenService tokenService = ContextUtils.getBean(TokenService.class);
|
TokenService tokenService = ContextUtils.getBean(TokenService.class);
|
||||||
try {
|
try {
|
||||||
UserWithPassword user = getUserWithPassword(userReq);
|
UserWithPassword user = getUserWithPassword(userReq);
|
||||||
return tokenService.generateToken(UserWithPassword.convert(user), appKey);
|
String token = tokenService.generateToken(UserWithPassword.convert(user), appKey);
|
||||||
|
updateLastLogin(userReq.getName());
|
||||||
|
return token;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("", e);
|
log.error("", e);
|
||||||
throw new RuntimeException("password encrypt error, please try again");
|
throw new RuntimeException("password encrypt error, please try again");
|
||||||
@@ -213,8 +223,9 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
||||||
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
||||||
|
|
||||||
String token =
|
// 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况
|
||||||
tokenService.generateToken(UserWithPassword.convert(userWithPassword), expireTime);
|
String token = tokenService.generateToken(UserWithPassword.convert(userWithPassword),
|
||||||
|
"SysDbToken:" + name, (new Date().getTime() + expireTime));
|
||||||
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
|
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
|
||||||
return convertUserToken(userTokenDO);
|
return convertUserToken(userTokenDO);
|
||||||
}
|
}
|
||||||
@@ -267,4 +278,11 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
userToken.setExpireDate(userTokenDO.getExpireDateTime());
|
userToken.setExpireDate(userTokenDO.getExpireDateTime());
|
||||||
return userToken;
|
return userToken;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void updateLastLogin(String userName) {
|
||||||
|
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||||
|
UserDO userDO = userRepository.getUser(userName);
|
||||||
|
userDO.setLastLogin(new Timestamp(System.currentTimeMillis()));
|
||||||
|
userRepository.updateUser(userDO);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.auth.authentication.interceptor;
|
package com.tencent.supersonic.auth.authentication.interceptor;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||||
import com.tencent.supersonic.auth.authentication.service.UserServiceImpl;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
import com.tencent.supersonic.auth.authentication.utils.TokenService;
|
import com.tencent.supersonic.auth.authentication.utils.TokenService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -16,7 +16,7 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
|||||||
|
|
||||||
protected AuthenticationConfig authenticationConfig;
|
protected AuthenticationConfig authenticationConfig;
|
||||||
|
|
||||||
protected UserServiceImpl userServiceImpl;
|
protected UserService userService;
|
||||||
|
|
||||||
protected TokenService tokenService;
|
protected TokenService tokenService;
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.auth.authentication.interceptor;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.annotation.AuthenticationIgnore;
|
import com.tencent.supersonic.auth.api.authentication.annotation.AuthenticationIgnore;
|
||||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
|
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
|
||||||
import com.tencent.supersonic.auth.authentication.service.UserServiceImpl;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
import com.tencent.supersonic.auth.authentication.utils.TokenService;
|
import com.tencent.supersonic.auth.authentication.utils.TokenService;
|
||||||
import com.tencent.supersonic.common.pojo.exception.AccessException;
|
import com.tencent.supersonic.common.pojo.exception.AccessException;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -16,12 +16,7 @@ import org.springframework.web.method.HandlerMethod;
|
|||||||
import java.lang.reflect.Method;
|
import java.lang.reflect.Method;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN;
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.*;
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME;
|
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_EMAIL;
|
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_ID;
|
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_NAME;
|
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_PASSWORD;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor {
|
public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor {
|
||||||
@@ -30,7 +25,7 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor
|
|||||||
public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
|
public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
|
||||||
Object handler) throws AccessException {
|
Object handler) throws AccessException {
|
||||||
authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class);
|
authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class);
|
||||||
userServiceImpl = ContextUtils.getBean(UserServiceImpl.class);
|
userService = ContextUtils.getBean(UserService.class);
|
||||||
tokenService = ContextUtils.getBean(TokenService.class);
|
tokenService = ContextUtils.getBean(TokenService.class);
|
||||||
if (!authenticationConfig.isEnabled()) {
|
if (!authenticationConfig.isEnabled()) {
|
||||||
return true;
|
return true;
|
||||||
|
|||||||
@@ -3,7 +3,11 @@ package com.tencent.supersonic.auth.authentication.persistence.dataobject;
|
|||||||
import com.baomidou.mybatisplus.annotation.IdType;
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
import com.baomidou.mybatisplus.annotation.TableId;
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
import com.baomidou.mybatisplus.annotation.TableName;
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.sql.Timestamp;
|
||||||
|
|
||||||
|
@Data
|
||||||
@TableName("s2_user")
|
@TableName("s2_user")
|
||||||
public class UserDO {
|
public class UserDO {
|
||||||
|
|
||||||
@@ -27,71 +31,25 @@ public class UserDO {
|
|||||||
/** */
|
/** */
|
||||||
private Integer isAdmin;
|
private Integer isAdmin;
|
||||||
|
|
||||||
/** @return id */
|
private Timestamp lastLogin;
|
||||||
public Long getId() {
|
|
||||||
return id;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @param id */
|
|
||||||
public void setId(Long id) {
|
|
||||||
this.id = id;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @return name */
|
|
||||||
public String getName() {
|
|
||||||
return name;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @param name */
|
/** @param name */
|
||||||
public void setName(String name) {
|
public void setName(String name) {
|
||||||
this.name = name == null ? null : name.trim();
|
this.name = name == null ? null : name.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @return password */
|
|
||||||
public String getPassword() {
|
|
||||||
return password;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @param password */
|
/** @param password */
|
||||||
public void setPassword(String password) {
|
public void setPassword(String password) {
|
||||||
this.password = password == null ? null : password.trim();
|
this.password = password == null ? null : password.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getSalt() {
|
|
||||||
return salt;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setSalt(String salt) {
|
public void setSalt(String salt) {
|
||||||
this.salt = salt == null ? null : salt.trim();
|
this.salt = salt == null ? null : salt.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @return display_name */
|
|
||||||
public String getDisplayName() {
|
|
||||||
return displayName;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @param displayName */
|
|
||||||
public void setDisplayName(String displayName) {
|
|
||||||
this.displayName = displayName == null ? null : displayName.trim();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @return email */
|
|
||||||
public String getEmail() {
|
|
||||||
return email;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @param email */
|
/** @param email */
|
||||||
public void setEmail(String email) {
|
public void setEmail(String email) {
|
||||||
this.email = email == null ? null : email.trim();
|
this.email = email == null ? null : email.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
/** @return is_admin */
|
|
||||||
public Integer getIsAdmin() {
|
|
||||||
return isAdmin;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @param isAdmin */
|
|
||||||
public void setIsAdmin(Integer isAdmin) {
|
|
||||||
this.isAdmin = isAdmin;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,11 @@ public interface UserRepository {
|
|||||||
|
|
||||||
UserTokenDO getUserToken(Long tokenId);
|
UserTokenDO getUserToken(Long tokenId);
|
||||||
|
|
||||||
|
UserTokenDO getUserTokenByName(String tokenName);
|
||||||
|
|
||||||
void deleteUserTokenByName(String userName);
|
void deleteUserTokenByName(String userName);
|
||||||
|
|
||||||
void deleteUserToken(Long tokenId);
|
void deleteUserToken(Long tokenId);
|
||||||
|
|
||||||
|
void deleteUser(long userId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,6 +65,13 @@ public class UserRepositoryImpl implements UserRepository {
|
|||||||
return userTokenDOMapper.selectById(tokenId);
|
return userTokenDOMapper.selectById(tokenId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public UserTokenDO getUserTokenByName(String tokenName) {
|
||||||
|
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda().eq(UserTokenDO::getName, tokenName);
|
||||||
|
return userTokenDOMapper.selectOne(queryWrapper);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void deleteUserTokenByName(String userName) {
|
public void deleteUserTokenByName(String userName) {
|
||||||
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||||
@@ -76,4 +83,9 @@ public class UserRepositoryImpl implements UserRepository {
|
|||||||
public void deleteUserToken(Long tokenId) {
|
public void deleteUserToken(Long tokenId) {
|
||||||
userTokenDOMapper.deleteById(tokenId);
|
userTokenDOMapper.deleteById(tokenId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteUser(long userId) {
|
||||||
|
userDOMapper.deleteById(userId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,13 +9,7 @@ import com.tencent.supersonic.common.pojo.User;
|
|||||||
import jakarta.servlet.http.HttpServletRequest;
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
import jakarta.servlet.http.HttpServletResponse;
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.web.bind.annotation.GetMapping;
|
import org.springframework.web.bind.annotation.*;
|
||||||
import org.springframework.web.bind.annotation.PathVariable;
|
|
||||||
import org.springframework.web.bind.annotation.PostMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
|
||||||
import org.springframework.web.bind.annotation.RequestMapping;
|
|
||||||
import org.springframework.web.bind.annotation.RequestParam;
|
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -67,6 +61,16 @@ public class UserController {
|
|||||||
userService.register(userCmd);
|
userService.register(userCmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@DeleteMapping("/delete/{userId}")
|
||||||
|
public void delete(@PathVariable("userId") long userId, HttpServletRequest httpServletRequest,
|
||||||
|
HttpServletResponse httpServletResponse) throws IllegalAccessException {
|
||||||
|
User user = userService.getCurrentUser(httpServletRequest, httpServletResponse);
|
||||||
|
if (user.getIsAdmin() != 1) {
|
||||||
|
throw new IllegalAccessException("only admin can delete user");
|
||||||
|
}
|
||||||
|
userService.deleteUser(userId);
|
||||||
|
}
|
||||||
|
|
||||||
@PostMapping("/login")
|
@PostMapping("/login")
|
||||||
public String login(@RequestBody UserReq userCmd, HttpServletRequest request) {
|
public String login(@RequestBody UserReq userCmd, HttpServletRequest request) {
|
||||||
return userService.login(userCmd, request);
|
return userService.login(userCmd, request);
|
||||||
|
|||||||
@@ -70,6 +70,11 @@ public class UserServiceImpl implements UserService {
|
|||||||
ComponentFactory.getUserAdaptor().register(userReq);
|
ComponentFactory.getUserAdaptor().register(userReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteUser(long userId) {
|
||||||
|
ComponentFactory.getUserAdaptor().deleteUser(userId);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String login(UserReq userReq, HttpServletRequest request) {
|
public String login(UserReq userReq, HttpServletRequest request) {
|
||||||
return ComponentFactory.getUserAdaptor().login(userReq, request);
|
return ComponentFactory.getUserAdaptor().login(userReq, request);
|
||||||
|
|||||||
@@ -6,7 +6,10 @@ import javax.crypto.spec.SecretKeySpec;
|
|||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
|
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
|
||||||
|
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserTokenDO;
|
||||||
|
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
|
||||||
import com.tencent.supersonic.common.pojo.exception.AccessException;
|
import com.tencent.supersonic.common.pojo.exception.AccessException;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import io.jsonwebtoken.Claims;
|
import io.jsonwebtoken.Claims;
|
||||||
import io.jsonwebtoken.Jwts;
|
import io.jsonwebtoken.Jwts;
|
||||||
import io.jsonwebtoken.SignatureAlgorithm;
|
import io.jsonwebtoken.SignatureAlgorithm;
|
||||||
@@ -71,6 +74,7 @@ public class TokenService {
|
|||||||
return generateToken(UserWithPassword.convert(appUser), request);
|
return generateToken(UserWithPassword.convert(appUser), request);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public Optional<Claims> getClaims(HttpServletRequest request) {
|
public Optional<Claims> getClaims(HttpServletRequest request) {
|
||||||
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
||||||
String appKey = getAppKey(request);
|
String appKey = getAppKey(request);
|
||||||
@@ -90,6 +94,14 @@ public class TokenService {
|
|||||||
|
|
||||||
public Optional<Claims> getClaims(String token, String appKey) {
|
public Optional<Claims> getClaims(String token, String appKey) {
|
||||||
try {
|
try {
|
||||||
|
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
|
||||||
|
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||||
|
UserTokenDO dbToken =
|
||||||
|
userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
|
||||||
|
if (dbToken == null || !dbToken.getToken().equals(token.replace("Bearer ", ""))) {
|
||||||
|
throw new AccessException("Token does not exist :" + appKey);
|
||||||
|
}
|
||||||
|
}
|
||||||
String tokenSecret = getTokenSecret(appKey);
|
String tokenSecret = getTokenSecret(appKey);
|
||||||
Claims claims =
|
Claims claims =
|
||||||
Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
|
Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
|
||||||
@@ -122,6 +134,16 @@ public class TokenService {
|
|||||||
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
|
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
|
||||||
String secret = appKeyToSecretMap.get(appKey);
|
String secret = appKeyToSecretMap.get(appKey);
|
||||||
if (StringUtils.isBlank(secret)) {
|
if (StringUtils.isBlank(secret)) {
|
||||||
|
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
|
||||||
|
String realAppKey = appKey.substring("SysDbToken:".length());
|
||||||
|
String tmp =
|
||||||
|
"WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
|
||||||
|
if (tmp.length() <= realAppKey.length()) {
|
||||||
|
return realAppKey;
|
||||||
|
} else {
|
||||||
|
return realAppKey + tmp.substring(realAppKey.length());
|
||||||
|
}
|
||||||
|
}
|
||||||
throw new AccessException("get secret from appKey failed :" + appKey);
|
throw new AccessException("get secret from appKey failed :" + appKey);
|
||||||
}
|
}
|
||||||
return secret;
|
return secret;
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
|
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
|
||||||
<result column="email" jdbcType="VARCHAR" property="email" />
|
<result column="email" jdbcType="VARCHAR" property="email" />
|
||||||
<result column="is_admin" jdbcType="INTEGER" property="isAdmin" />
|
<result column="is_admin" jdbcType="INTEGER" property="isAdmin" />
|
||||||
|
<result column="last_login" jdbcType="TIMESTAMP" property="lastLogin" />
|
||||||
</resultMap>
|
</resultMap>
|
||||||
<sql id="Example_Where_Clause">
|
<sql id="Example_Where_Clause">
|
||||||
<where>
|
<where>
|
||||||
@@ -40,7 +41,7 @@
|
|||||||
</where>
|
</where>
|
||||||
</sql>
|
</sql>
|
||||||
<sql id="Base_Column_List">
|
<sql id="Base_Column_List">
|
||||||
id, name, password, salt, display_name, email, is_admin
|
id, name, password, salt, display_name, email, is_admin, last_login
|
||||||
</sql>
|
</sql>
|
||||||
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
|
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
|
||||||
select
|
select
|
||||||
@@ -136,6 +137,9 @@
|
|||||||
<if test="isAdmin != null">
|
<if test="isAdmin != null">
|
||||||
is_admin = #{isAdmin,jdbcType=INTEGER},
|
is_admin = #{isAdmin,jdbcType=INTEGER},
|
||||||
</if>
|
</if>
|
||||||
|
<if test="lastLogin != null">
|
||||||
|
last_login = #{lastLogin,jdbcType=TIMESTAMP},
|
||||||
|
</if>
|
||||||
</set>
|
</set>
|
||||||
where id = #{id,jdbcType=BIGINT}
|
where id = #{id,jdbcType=BIGINT}
|
||||||
</update>
|
</update>
|
||||||
|
|||||||
@@ -15,6 +15,68 @@ import requests
|
|||||||
import time
|
import time
|
||||||
import jwt
|
import jwt
|
||||||
import traceback
|
import traceback
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class DataFrameAppender:
|
||||||
|
def __init__(self,file_name = "output"):
|
||||||
|
# 定义表头
|
||||||
|
columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
|
||||||
|
# 创建只有表头的 DataFrame
|
||||||
|
self.df = pd.DataFrame(columns=columns)
|
||||||
|
self.file_name = file_name
|
||||||
|
|
||||||
|
def append_data(self, new_data):
|
||||||
|
# 假设 new_data 是一维数组,将其转换为字典
|
||||||
|
columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
|
||||||
|
new_dict = dict(zip(columns, new_data))
|
||||||
|
# 使用 loc 方法追加数据
|
||||||
|
self.df.loc[len(self.df)] = new_dict
|
||||||
|
def print_analysis_result(self):
|
||||||
|
# 测试样例总数
|
||||||
|
total_samples = len(self.df)
|
||||||
|
|
||||||
|
# 解析成功数量
|
||||||
|
parse_success_count = (self.df['解析状态'] == '解析成功').sum()
|
||||||
|
|
||||||
|
# 执行成功数量
|
||||||
|
execute_success_count = (self.df['执行状态'] == '执行成功').sum()
|
||||||
|
|
||||||
|
# 解析平均耗时,保留两位小数
|
||||||
|
avg_parse_time = round(self.df['解析耗时'].mean(), 2)
|
||||||
|
|
||||||
|
# 执行平均耗时,保留两位小数
|
||||||
|
avg_execute_time = round(self.df['执行耗时'].mean(), 2)
|
||||||
|
|
||||||
|
# 总平均耗时,保留两位小数
|
||||||
|
avg_total_time = round(self.df['总耗时'].mean(), 2)
|
||||||
|
|
||||||
|
# 最长耗时,保留两位小数
|
||||||
|
max_time = round(self.df['总耗时'].max(), 2)
|
||||||
|
|
||||||
|
# 最短耗时,保留两位小数
|
||||||
|
min_time = round(self.df['总耗时'].min(), 2)
|
||||||
|
|
||||||
|
print(f"测试样例总数 : {total_samples}")
|
||||||
|
print(f"解析成功数量 : {parse_success_count}")
|
||||||
|
print(f"执行成功数量 : {execute_success_count}")
|
||||||
|
print(f"解析平均耗时 : {avg_parse_time} 秒")
|
||||||
|
print(f"执行平均耗时 : {avg_execute_time} 秒")
|
||||||
|
print(f"总平均耗时 : {avg_total_time} 秒")
|
||||||
|
print(f"最长耗时 : {max_time} 秒")
|
||||||
|
print(f"最短耗时 : {min_time} 秒")
|
||||||
|
|
||||||
|
def write_to_csv(self):
|
||||||
|
# 检查 data 文件夹是否存在,如果不存在则创建
|
||||||
|
if not os.path.exists('res'):
|
||||||
|
os.makedirs('res')
|
||||||
|
# 获取当前时间戳
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||||
|
# 生成带时间戳的文件名
|
||||||
|
file_path = os.path.join('res', f'{self.file_name}_{timestamp}.csv')
|
||||||
|
self.df.to_csv(file_path, index=False)
|
||||||
|
print(f"测试结果已保存到 {file_path}")
|
||||||
|
|
||||||
class BatchTest:
|
class BatchTest:
|
||||||
def __init__(self, url, agentId, chatId, userName):
|
def __init__(self, url, agentId, chatId, userName):
|
||||||
@@ -70,18 +132,35 @@ class BatchTest:
|
|||||||
def benchmark(url:str, agentId:str, chatId:str, filePath:str, userName:str):
|
def benchmark(url:str, agentId:str, chatId:str, filePath:str, userName:str):
|
||||||
batch_test = BatchTest(url, agentId, chatId, userName)
|
batch_test = BatchTest(url, agentId, chatId, userName)
|
||||||
df = batch_test.read_question_from_csv(filePath)
|
df = batch_test.read_question_from_csv(filePath)
|
||||||
|
appender = DataFrameAppender(os.path.basename(filePath))
|
||||||
for index, row in df.iterrows():
|
for index, row in df.iterrows():
|
||||||
question = row['question']
|
question = row['question']
|
||||||
print('start to ask question:', question)
|
print('start to ask question:', question)
|
||||||
# 捕获异常,防止程序中断
|
# 捕获异常,防止程序中断
|
||||||
try:
|
try:
|
||||||
parse_resp = batch_test.parse(question)
|
parse_resp = batch_test.parse(question)
|
||||||
batch_test.execute(agentId, question, parse_resp['data']['queryId'])
|
parse_status = '解析失败'
|
||||||
|
if parse_resp.get('data').get('errorMsg') is None:
|
||||||
|
parse_status = '解析成功'
|
||||||
|
parse_cost = parse_resp.get('data').get('parseTimeCost').get('parseTime')
|
||||||
|
execute_resp = batch_test.execute(agentId, question, parse_resp['data']['queryId'])
|
||||||
|
execute_status = '执行失败'
|
||||||
|
execute_cost = 0
|
||||||
|
if parse_status == '解析成功' and execute_resp.get('data').get('errorMsg') is None:
|
||||||
|
execute_status = '执行成功'
|
||||||
|
execute_cost = execute_resp.get('data').get('queryTimeCost')
|
||||||
|
res = [question.replace(',', '#'),parse_status,parse_cost/1000,execute_status,execute_cost/1000,(parse_cost+execute_cost)/1000]
|
||||||
|
appender.append_data(res)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print('error:', e)
|
print('error:', e)
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
# 打印分析结果
|
||||||
|
appender.print_analysis_result()
|
||||||
|
# 分析明细输出
|
||||||
|
appender.write_to_csv()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
|||||||
@@ -5,5 +5,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
|||||||
|
|
||||||
public interface ChatQueryExecutor {
|
public interface ChatQueryExecutor {
|
||||||
|
|
||||||
|
boolean accept(ExecuteContext executeContext);
|
||||||
|
|
||||||
QueryResult execute(ExecuteContext executeContext);
|
QueryResult execute(ExecuteContext executeContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,11 +37,12 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ExecuteContext executeContext) {
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
return "PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode());
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||||
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
||||||
|
|||||||
@@ -8,6 +8,11 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
|
|
||||||
public class PluginExecutor implements ChatQueryExecutor {
|
public class PluginExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
|
return PluginQueryManager.isPluginQuery(executeContext.getParseInfo().getQueryMode());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ExecuteContext executeContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ import java.util.Objects;
|
|||||||
|
|
||||||
public class SqlExecutor implements ChatQueryExecutor {
|
public class SqlExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ExecuteContext executeContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
@@ -70,8 +75,12 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
QuerySqlReq sqlReq =
|
// 使用querySQL,它已经包含了所有修正(包括物理SQL修正)
|
||||||
QuerySqlReq.builder().sql(parseInfo.getSqlInfo().getCorrectedS2SQL()).build();
|
String finalSql = StringUtils.isNotBlank(parseInfo.getSqlInfo().getQuerySQL())
|
||||||
|
? parseInfo.getSqlInfo().getQuerySQL()
|
||||||
|
: parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||||
|
|
||||||
|
QuerySqlReq sqlReq = QuerySqlReq.builder().sql(finalSql).build();
|
||||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||||
|
|
||||||
@@ -80,12 +89,12 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
queryResult.setQueryId(executeContext.getRequest().getQueryId());
|
queryResult.setQueryId(executeContext.getRequest().getQueryId());
|
||||||
queryResult.setChatContext(parseInfo);
|
queryResult.setChatContext(parseInfo);
|
||||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
|
||||||
SemanticQueryResp queryResp =
|
SemanticQueryResp queryResp =
|
||||||
semanticLayer.queryByReq(sqlReq, executeContext.getRequest().getUser());
|
semanticLayer.queryByReq(sqlReq, executeContext.getRequest().getUser());
|
||||||
|
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||||
if (queryResp != null) {
|
if (queryResp != null) {
|
||||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||||
queryResult.setQuerySql(queryResp.getSql());
|
queryResult.setQuerySql(finalSql);
|
||||||
queryResult.setQueryResults(queryResp.getResultList());
|
queryResult.setQueryResults(queryResp.getResultList());
|
||||||
queryResult.setQueryColumns(queryResp.getColumns());
|
queryResult.setQueryColumns(queryResp.getColumns());
|
||||||
queryResult.setQueryState(QueryState.SUCCESS);
|
queryResult.setQueryState(QueryState.SUCCESS);
|
||||||
|
|||||||
@@ -4,5 +4,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
|||||||
|
|
||||||
public interface ChatQueryParser {
|
public interface ChatQueryParser {
|
||||||
|
|
||||||
|
boolean accept(ParseContext parseContext);
|
||||||
|
|
||||||
void parse(ParseContext parseContext);
|
void parse(ParseContext parseContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,12 +14,12 @@ public class NL2PluginParser implements ChatQueryParser {
|
|||||||
private final List<PluginRecognizer> pluginRecognizers =
|
private final List<PluginRecognizer> pluginRecognizers =
|
||||||
ComponentFactory.getPluginRecognizers();
|
ComponentFactory.getPluginRecognizers();
|
||||||
|
|
||||||
@Override
|
public boolean accept(ParseContext parseContext) {
|
||||||
public void parse(ParseContext parseContext) {
|
return parseContext.getAgent().containsPluginTool();
|
||||||
if (!parseContext.getAgent().containsPluginTool()) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(ParseContext parseContext) {
|
||||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||||
pluginRecognizer.recognize(parseContext);
|
pluginRecognizer.recognize(parseContext);
|
||||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||||
|
|||||||
@@ -73,12 +73,12 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public boolean accept(ParseContext parseContext) {
|
||||||
public void parse(ParseContext parseContext) {
|
return parseContext.enableNL2SQL();
|
||||||
if (!parseContext.enableNL2SQL()) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(ParseContext parseContext) {
|
||||||
// first go with rule-based parsers unless the user has already selected one parse.
|
// first go with rule-based parsers unless the user has already selected one parse.
|
||||||
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|||||||
|
|
||||||
public class PlainTextParser implements ChatQueryParser {
|
public class PlainTextParser implements ChatQueryParser {
|
||||||
|
|
||||||
@Override
|
public boolean accept(ParseContext parseContext) {
|
||||||
public void parse(ParseContext parseContext) {
|
return !parseContext.getAgent().containsAnyTool();
|
||||||
if (parseContext.getAgent().containsAnyTool()) {
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(ParseContext parseContext) {
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||||
parseInfo.setId(1);
|
parseInfo.setId(1);
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@TableName("s2_chat")
|
||||||
public class ChatDO {
|
public class ChatDO {
|
||||||
|
|
||||||
private long chatId;
|
@TableId(type = IdType.AUTO)
|
||||||
|
private Long chatId;
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
private String chatName;
|
private String chatName;
|
||||||
private String createTime;
|
private String createTime;
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.Date;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DictConfDO {
|
|
||||||
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
private String dimValueInfos;
|
|
||||||
|
|
||||||
private String createdBy;
|
|
||||||
private String updatedBy;
|
|
||||||
private Date createdAt;
|
|
||||||
private Date updatedAt;
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.ToString;
|
|
||||||
import org.apache.commons.codec.digest.DigestUtils;
|
|
||||||
|
|
||||||
import java.util.Date;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@ToString
|
|
||||||
public class DictTaskDO {
|
|
||||||
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
private String name;
|
|
||||||
|
|
||||||
private String description;
|
|
||||||
|
|
||||||
private String command;
|
|
||||||
|
|
||||||
private String commandMd5;
|
|
||||||
|
|
||||||
private String dimIds;
|
|
||||||
|
|
||||||
private Integer status;
|
|
||||||
|
|
||||||
private String createdBy;
|
|
||||||
|
|
||||||
private Date createdAt;
|
|
||||||
|
|
||||||
private Double progress;
|
|
||||||
|
|
||||||
private Long elapsedMs;
|
|
||||||
|
|
||||||
public String getCommandMd5() {
|
|
||||||
return DigestUtils.md5Hex(command);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -148,7 +148,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
chatQueryDO.setUserName(chatParseReq.getUser().getName());
|
chatQueryDO.setUserName(chatParseReq.getUser().getName());
|
||||||
chatQueryDO.setQueryText(chatParseReq.getQueryText());
|
chatQueryDO.setQueryText(chatParseReq.getQueryText());
|
||||||
chatQueryDO.setAgentId(chatParseReq.getAgentId());
|
chatQueryDO.setAgentId(chatParseReq.getAgentId());
|
||||||
chatQueryDO.setQueryResult("");
|
chatQueryDO.setQueryResult("{}");
|
||||||
chatQueryDO.setQueryState(1);
|
chatQueryDO.setQueryState(1);
|
||||||
try {
|
try {
|
||||||
chatQueryDOMapper.insert(chatQueryDO);
|
chatQueryDOMapper.insert(chatQueryDO);
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
|||||||
public boolean accept(ExecuteContext executeContext) {
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
Agent agent = executeContext.getAgent();
|
Agent agent = executeContext.getAgent();
|
||||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||||
return Objects.nonNull(chatApp) && chatApp.isEnable();
|
return Objects.nonNull(chatApp) && chatApp.isEnable()
|
||||||
|
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
|
||||||
|
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -56,7 +58,16 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
|||||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||||
|
|
||||||
Map<String, Object> variable = new HashMap<>();
|
Map<String, Object> variable = new HashMap<>();
|
||||||
variable.put("question", executeContext.getRequest().getQueryText());
|
String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
|
||||||
|
if (executeContext.getParseInfo().getProperties() != null
|
||||||
|
&& executeContext.getParseInfo().getProperties().containsKey("CONTEXT")) {
|
||||||
|
Map<String, Object> context = (Map<String, Object>) executeContext.getParseInfo()
|
||||||
|
.getProperties().get("CONTEXT");
|
||||||
|
if (context.get("queryText") != null && "".equals(context.get("queryText"))) {
|
||||||
|
question = context.get("queryText").toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
variable.put("question", question);
|
||||||
variable.put("data", queryResult.getTextResult());
|
variable.put("data", queryResult.getTextResult());
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||||
|
|||||||
@@ -66,8 +66,10 @@ public class ChatConfigController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/getDomainDataSetTree")
|
@GetMapping("/getDomainDataSetTree")
|
||||||
public List<ItemResp> getDomainDataSetTree() {
|
public List<ItemResp> getDomainDataSetTree(HttpServletRequest request,
|
||||||
return semanticLayerService.getDomainDataSetTree();
|
HttpServletResponse response) {
|
||||||
|
User user = UserHolder.findUser(request, response);
|
||||||
|
return semanticLayerService.getDomainDataSetTree(user);
|
||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/getDataSetSchema/{id}")
|
@GetMapping("/getDataSetSchema/{id}")
|
||||||
|
|||||||
@@ -22,11 +22,10 @@ public class ChatController {
|
|||||||
private ChatManageService chatService;
|
private ChatManageService chatService;
|
||||||
|
|
||||||
@PostMapping("/save")
|
@PostMapping("/save")
|
||||||
public Boolean save(@RequestParam(value = "chatName") String chatName,
|
public Long save(@RequestParam(value = "chatName") String chatName,
|
||||||
@RequestParam(value = "agentId", required = false) Integer agentId,
|
@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||||
HttpServletRequest request, HttpServletResponse response) {
|
HttpServletRequest request, HttpServletResponse response) {
|
||||||
chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
|
return chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/getAll")
|
@GetMapping("/getAll")
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Qualifier;
|
import org.springframework.beans.factory.annotation.Qualifier;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -39,6 +40,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
private MemoryService memoryService;
|
private MemoryService memoryService;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
|
@Lazy
|
||||||
private ChatQueryService chatQueryService;
|
private ChatQueryService chatQueryService;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
@@ -161,9 +163,11 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
||||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||||
agent.getChatAppConfig().values().forEach(c -> {
|
agent.getChatAppConfig().values().forEach(c -> {
|
||||||
|
if (c.isEnable()) {// 优化,减少访问数据库的次数
|
||||||
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId());
|
||||||
if (Objects.nonNull(chatModel)) {
|
if (Objects.nonNull(chatModel)) {
|
||||||
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
|
c.setChatModelConfig(chatModel.getConfig());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
||||||
|
|||||||
@@ -233,6 +233,10 @@ public class ChatManageServiceImpl implements ChatManageService {
|
|||||||
@Override
|
@Override
|
||||||
public SemanticParseInfo getParseInfo(Long questionId, int parseId) {
|
public SemanticParseInfo getParseInfo(Long questionId, int parseId) {
|
||||||
ChatParseDO chatParseDO = chatQueryRepository.getParseInfo(questionId, parseId);
|
ChatParseDO chatParseDO = chatQueryRepository.getParseInfo(questionId, parseId);
|
||||||
|
if (chatParseDO == null) {
|
||||||
|
return null;
|
||||||
|
} else {
|
||||||
return JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
return JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ import net.sf.jsqlparser.schema.Column;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -66,6 +67,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private SemanticLayerService semanticLayerService;
|
private SemanticLayerService semanticLayerService;
|
||||||
@Autowired
|
@Autowired
|
||||||
|
@Lazy
|
||||||
private AgentService agentService;
|
private AgentService agentService;
|
||||||
|
|
||||||
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||||
@@ -95,7 +97,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
||||||
chatQueryParsers.forEach(p -> p.parse(parseContext));
|
for (ChatQueryParser parser : chatQueryParsers) {
|
||||||
|
if (parser.accept(parseContext)) {
|
||||||
|
parser.parse(parseContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||||
if (processor.accept(parseContext)) {
|
if (processor.accept(parseContext)) {
|
||||||
@@ -116,11 +122,13 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
QueryResult queryResult = new QueryResult();
|
QueryResult queryResult = new QueryResult();
|
||||||
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
||||||
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
||||||
|
if (chatQueryExecutor.accept(executeContext)) {
|
||||||
queryResult = chatQueryExecutor.execute(executeContext);
|
queryResult = chatQueryExecutor.execute(executeContext);
|
||||||
if (queryResult != null) {
|
if (queryResult != null) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
executeContext.setResponse(queryResult);
|
executeContext.setResponse(queryResult);
|
||||||
if (queryResult != null) {
|
if (queryResult != null) {
|
||||||
|
|||||||
@@ -18,9 +18,11 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.boot.CommandLineRunner;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -30,7 +32,8 @@ import java.util.Objects;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class MemoryServiceImpl implements MemoryService {
|
@Slf4j
|
||||||
|
public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private ChatMemoryRepository chatMemoryRepository;
|
private ChatMemoryRepository chatMemoryRepository;
|
||||||
@@ -61,12 +64,17 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
|
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
|
||||||
boolean hadEnabled =
|
boolean hadEnabled =
|
||||||
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
|
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
|
||||||
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
|
|
||||||
|
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus())) {
|
||||||
|
// Update the latest SQL/Schema to vector DB once memory is enabled
|
||||||
|
chatMemoryDO.setS2sql(chatMemoryUpdateReq.getS2sql());
|
||||||
|
chatMemoryDO.setDbSchema(chatMemoryUpdateReq.getDbSchema());
|
||||||
enableMemory(chatMemoryDO);
|
enableMemory(chatMemoryDO);
|
||||||
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
|
} else if ((MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus())
|
||||||
|
|| MemoryStatus.PENDING.equals(chatMemoryUpdateReq.getStatus())) && hadEnabled) {
|
||||||
|
// Remove from vector DB when transitioning: launched→disabled OR enabled→pending
|
||||||
disableMemory(chatMemoryDO);
|
disableMemory(chatMemoryDO);
|
||||||
}
|
}
|
||||||
|
|
||||||
LambdaUpdateWrapper<ChatMemoryDO> updateWrapper = new LambdaUpdateWrapper<>();
|
LambdaUpdateWrapper<ChatMemoryDO> updateWrapper = new LambdaUpdateWrapper<>();
|
||||||
updateWrapper.eq(ChatMemoryDO::getId, chatMemoryDO.getId());
|
updateWrapper.eq(ChatMemoryDO::getId, chatMemoryDO.getId());
|
||||||
if (Objects.nonNull(chatMemoryUpdateReq.getStatus())) {
|
if (Objects.nonNull(chatMemoryUpdateReq.getStatus())) {
|
||||||
@@ -87,6 +95,12 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt,
|
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt,
|
||||||
chatMemoryUpdateReq.getHumanReviewCmt());
|
chatMemoryUpdateReq.getHumanReviewCmt());
|
||||||
}
|
}
|
||||||
|
if (Objects.nonNull(chatMemoryUpdateReq.getDbSchema())) {
|
||||||
|
updateWrapper.set(ChatMemoryDO::getDbSchema, chatMemoryUpdateReq.getDbSchema());
|
||||||
|
}
|
||||||
|
if (Objects.nonNull(chatMemoryUpdateReq.getS2sql())) {
|
||||||
|
updateWrapper.set(ChatMemoryDO::getS2sql, chatMemoryUpdateReq.getS2sql());
|
||||||
|
}
|
||||||
updateWrapper.set(ChatMemoryDO::getUpdatedAt, new Date());
|
updateWrapper.set(ChatMemoryDO::getUpdatedAt, new Date());
|
||||||
updateWrapper.set(ChatMemoryDO::getUpdatedBy, user.getName());
|
updateWrapper.set(ChatMemoryDO::getUpdatedBy, user.getName());
|
||||||
|
|
||||||
@@ -95,6 +109,14 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void batchDelete(List<Long> ids) {
|
public void batchDelete(List<Long> ids) {
|
||||||
|
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda().in(ChatMemoryDO::getId, ids);
|
||||||
|
List<ChatMemoryDO> chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper);
|
||||||
|
chatMemoryDOS.forEach(chatMemoryDO -> {
|
||||||
|
if (MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim())) {
|
||||||
|
disableMemory(chatMemoryDO);
|
||||||
|
}
|
||||||
|
});
|
||||||
chatMemoryRepository.batchDelete(ids);
|
chatMemoryRepository.batchDelete(ids);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -187,4 +209,25 @@ public class MemoryServiceImpl implements MemoryService {
|
|||||||
return memory;
|
return memory;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库
|
||||||
|
loadSysExemplars();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void loadSysExemplars() {
|
||||||
|
try {
|
||||||
|
List<ChatMemory> memories = this
|
||||||
|
.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
|
||||||
|
for (ChatMemory memory : memories) {
|
||||||
|
exemplarService.storeExemplar(
|
||||||
|
embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||||
|
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||||
|
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||||
|
.sql(memory.getS2sql()).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Failed to load system exemplars", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,10 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-validation</artifactId>
|
<artifactId>spring-boot-starter-validation</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.slf4j</groupId>
|
<groupId>org.slf4j</groupId>
|
||||||
@@ -33,7 +36,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.httpcomponents.client5</groupId>
|
<groupId>org.apache.httpcomponents.client5</groupId>
|
||||||
<artifactId>httpclient5</artifactId>
|
<artifactId>httpclient5</artifactId>
|
||||||
<version>${httpclient5.version}</version> <!-- 请确认使用最新稳定版本 -->
|
<version>${httpclient5.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<!-- <dependency>-->
|
<!-- <dependency>-->
|
||||||
<!-- <groupId>org.apache.httpcomponents</groupId>-->
|
<!-- <groupId>org.apache.httpcomponents</groupId>-->
|
||||||
@@ -182,10 +185,6 @@
|
|||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-pgvector</artifactId>
|
<artifactId>langchain4j-pgvector</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||||
@@ -198,34 +197,6 @@
|
|||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-qianfan</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-zhipu-ai</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-dashscope</artifactId>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-simple</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-chatglm</artifactId>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-simple</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-ollama</artifactId>
|
<artifactId>langchain4j-ollama</artifactId>
|
||||||
@@ -237,11 +208,6 @@
|
|||||||
<version>${hanlp.version}</version>
|
<version>${hanlp.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.boot</groupId>
|
|
||||||
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.google.code.gson</groupId>
|
<groupId>com.google.code.gson</groupId>
|
||||||
<artifactId>gson</artifactId>
|
<artifactId>gson</artifactId>
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ public class LoadRemoveService {
|
|||||||
List<String> resultList = new ArrayList<>(value);
|
List<String> resultList = new ArrayList<>(value);
|
||||||
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
||||||
resultList.removeIf(nature -> {
|
resultList.removeIf(nature -> {
|
||||||
if (Objects.isNull(nature)) {
|
if (Objects.isNull(nature) || !nature.startsWith("_")) { // 系统的字典是以 _ 开头的,
|
||||||
|
// 过滤因引用外部字典导致的异常
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
Long id = getId(nature);
|
Long id = getId(nature);
|
||||||
|
|||||||
@@ -77,11 +77,6 @@ public class SemanticSqlConformance implements SqlConformance {
|
|||||||
return SqlConformanceEnum.BIG_QUERY.isMinusAllowed();
|
return SqlConformanceEnum.BIG_QUERY.isMinusAllowed();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isRegexReplaceCaptureGroupDollarIndexed() {
|
|
||||||
return SqlConformanceEnum.BIG_QUERY.isRegexReplaceCaptureGroupDollarIndexed();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean isApplyAllowed() {
|
public boolean isApplyAllowed() {
|
||||||
return SqlConformanceEnum.BIG_QUERY.isApplyAllowed();
|
return SqlConformanceEnum.BIG_QUERY.isApplyAllowed();
|
||||||
|
|||||||
@@ -26,6 +26,16 @@ public class SqlDialectFactory {
|
|||||||
.withLiteralQuoteString("'").withIdentifierQuoteString("\"")
|
.withLiteralQuoteString("'").withIdentifierQuoteString("\"")
|
||||||
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
||||||
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
|
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
|
||||||
|
public static final Context PRESTO_CONTEXT =
|
||||||
|
SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.PRESTO)
|
||||||
|
.withLiteralQuoteString("'").withIdentifierQuoteString("\"")
|
||||||
|
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
||||||
|
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(true);
|
||||||
|
public static final Context KYUUBI_CONTEXT =
|
||||||
|
SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
|
||||||
|
.withLiteralQuoteString("'").withIdentifierQuoteString("`")
|
||||||
|
.withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
|
||||||
|
.withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
|
||||||
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
|
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
@@ -35,6 +45,10 @@ public class SqlDialectFactory {
|
|||||||
sqlDialectMap.put(EngineType.H2, new SemanticSqlDialect(DEFAULT_CONTEXT));
|
sqlDialectMap.put(EngineType.H2, new SemanticSqlDialect(DEFAULT_CONTEXT));
|
||||||
sqlDialectMap.put(EngineType.POSTGRESQL, new SemanticSqlDialect(POSTGRESQL_CONTEXT));
|
sqlDialectMap.put(EngineType.POSTGRESQL, new SemanticSqlDialect(POSTGRESQL_CONTEXT));
|
||||||
sqlDialectMap.put(EngineType.HANADB, new SemanticSqlDialect(HANADB_CONTEXT));
|
sqlDialectMap.put(EngineType.HANADB, new SemanticSqlDialect(HANADB_CONTEXT));
|
||||||
|
sqlDialectMap.put(EngineType.STARROCKS, new SemanticSqlDialect(DEFAULT_CONTEXT));
|
||||||
|
sqlDialectMap.put(EngineType.KYUUBI, new SemanticSqlDialect(KYUUBI_CONTEXT));
|
||||||
|
sqlDialectMap.put(EngineType.PRESTO, new SemanticSqlDialect(PRESTO_CONTEXT));
|
||||||
|
sqlDialectMap.put(EngineType.TRINO, new SemanticSqlDialect(PRESTO_CONTEXT));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static SemanticSqlDialect getSqlDialect(EngineType engineType) {
|
public static SemanticSqlDialect getSqlDialect(EngineType engineType) {
|
||||||
|
|||||||
@@ -2,19 +2,11 @@ package com.tencent.supersonic.common.calcite;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.calcite.sql.SqlIdentifier;
|
import net.sf.jsqlparser.expression.Alias;
|
||||||
import org.apache.calcite.sql.SqlLiteral;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.calcite.sql.SqlNode;
|
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
|
||||||
import org.apache.calcite.sql.SqlNodeList;
|
import net.sf.jsqlparser.statement.select.Select;
|
||||||
import org.apache.calcite.sql.SqlOrderBy;
|
import net.sf.jsqlparser.statement.select.WithItem;
|
||||||
import org.apache.calcite.sql.SqlSelect;
|
|
||||||
import org.apache.calcite.sql.SqlWith;
|
|
||||||
import org.apache.calcite.sql.SqlWithItem;
|
|
||||||
import org.apache.calcite.sql.SqlWriterConfig;
|
|
||||||
import org.apache.calcite.sql.parser.SqlParseException;
|
|
||||||
import org.apache.calcite.sql.parser.SqlParser;
|
|
||||||
import org.apache.calcite.sql.parser.SqlParserPos;
|
|
||||||
import org.apache.calcite.sql.pretty.SqlPrettyWriter;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -22,85 +14,37 @@ import java.util.List;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlMergeWithUtils {
|
public class SqlMergeWithUtils {
|
||||||
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
|
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
|
||||||
List<String> parentWithNameList) throws SqlParseException {
|
List<String> parentWithNameList) throws Exception {
|
||||||
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
|
|
||||||
|
|
||||||
// Parse the main SQL statement
|
Select selectStatement = (Select) CCJSqlParserUtil.parse(sql);
|
||||||
SqlParser parser = SqlParser.create(sql, parserConfig);
|
List<WithItem> withItemList = new ArrayList<>();
|
||||||
SqlNode sqlNode1 = parser.parseQuery();
|
|
||||||
|
|
||||||
// List to hold all WITH items
|
|
||||||
List<SqlNode> withItemList = new ArrayList<>();
|
|
||||||
|
|
||||||
// Iterate over each parentSql and parentWithName pair
|
|
||||||
for (int i = 0; i < parentSqlList.size(); i++) {
|
for (int i = 0; i < parentSqlList.size(); i++) {
|
||||||
String parentSql = parentSqlList.get(i);
|
String parentSql = parentSqlList.get(i);
|
||||||
String parentWithName = parentWithNameList.get(i);
|
String parentWithName = parentWithNameList.get(i);
|
||||||
|
|
||||||
// Parse the parent SQL statement
|
Select parentSelect = (Select) CCJSqlParserUtil.parse(parentSql);
|
||||||
parser = SqlParser.create(parentSql, parserConfig);
|
ParenthesedSelect select = new ParenthesedSelect();
|
||||||
SqlNode sqlNode2 = parser.parseQuery();
|
select.setSelect(parentSelect);
|
||||||
|
|
||||||
// Create a new WITH item for parentWithName without quotes
|
// Create a new WITH item for parentWithName without quotes
|
||||||
SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO,
|
WithItem withItem = new WithItem();
|
||||||
new SqlIdentifier(parentWithName, SqlParserPos.ZERO), null, sqlNode2,
|
withItem.setAlias(new Alias(parentWithName));
|
||||||
SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
|
withItem.setSelect(select);
|
||||||
|
|
||||||
// Add the new WITH item to the list
|
// Add the new WITH item to the list
|
||||||
withItemList.add(withItem);
|
withItemList.add(withItem);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the main SQL node contains an ORDER BY or LIMIT clause
|
// Extract existing WITH items from mainSelectBody if it has any
|
||||||
SqlNode limitNode = null;
|
if (selectStatement.getWithItemsList() != null) {
|
||||||
SqlNodeList orderByList = null;
|
withItemList.addAll(selectStatement.getWithItemsList());
|
||||||
if (sqlNode1 instanceof SqlOrderBy) {
|
|
||||||
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1;
|
|
||||||
limitNode = sqlOrderBy.fetch;
|
|
||||||
orderByList = sqlOrderBy.orderList;
|
|
||||||
sqlNode1 = sqlOrderBy.query;
|
|
||||||
} else if (sqlNode1 instanceof SqlSelect) {
|
|
||||||
SqlSelect sqlSelect = (SqlSelect) sqlNode1;
|
|
||||||
limitNode = sqlSelect.getFetch();
|
|
||||||
sqlSelect.setFetch(null);
|
|
||||||
sqlNode1 = sqlSelect;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract existing WITH items from sqlNode1 if it is a SqlWith
|
// Set the new WITH items list to the main select body
|
||||||
if (sqlNode1 instanceof SqlWith) {
|
selectStatement.setWithItemsList(withItemList);
|
||||||
SqlWith sqlWith = (SqlWith) sqlNode1;
|
|
||||||
withItemList.addAll(sqlWith.withList.getList());
|
|
||||||
sqlNode1 = sqlWith.body;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new SqlWith node
|
|
||||||
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
|
|
||||||
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
|
|
||||||
|
|
||||||
// If there was an ORDER BY or LIMIT clause, wrap the finalSqlNode in a SqlOrderBy
|
|
||||||
SqlNode resultNode = finalSqlNode;
|
|
||||||
if (orderByList != null || limitNode != null) {
|
|
||||||
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode,
|
|
||||||
orderByList != null ? orderByList : SqlNodeList.EMPTY, null, limitNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Custom SqlPrettyWriter configuration to avoid quoting identifiers
|
|
||||||
SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType);
|
|
||||||
// Pretty print the final SQL
|
// Pretty print the final SQL
|
||||||
SqlPrettyWriter writer = new SqlPrettyWriter(config);
|
return selectStatement.toString();
|
||||||
return writer.format(resultNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static boolean hasWith(EngineType engineType, String sql) throws SqlParseException {
|
|
||||||
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
|
|
||||||
SqlParser parser = SqlParser.create(sql, parserConfig);
|
|
||||||
SqlNode sqlNode = parser.parseQuery();
|
|
||||||
SqlNode sqlSelect = sqlNode;
|
|
||||||
if (sqlNode instanceof SqlOrderBy) {
|
|
||||||
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode;
|
|
||||||
sqlSelect = sqlOrderBy.query;
|
|
||||||
} else if (sqlNode instanceof SqlSelect) {
|
|
||||||
sqlSelect = (SqlSelect) sqlNode;
|
|
||||||
}
|
|
||||||
return sqlSelect instanceof SqlWith;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package com.tencent.supersonic.common.config;
|
package com.tencent.supersonic.common.config;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ChatModel {
|
public class ChatModel {
|
||||||
@@ -25,5 +27,11 @@ public class ChatModel {
|
|||||||
|
|
||||||
private String admin;
|
private String admin;
|
||||||
|
|
||||||
private String viewer;
|
private List<String> viewers = Lists.newArrayList();
|
||||||
|
|
||||||
|
private Integer isOpen = 0;
|
||||||
|
|
||||||
|
public boolean isPublic() {
|
||||||
|
return isOpen != null && isOpen == 1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,14 +4,10 @@ import com.google.common.collect.ImmutableMap;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import dev.langchain4j.provider.AzureModelFactory;
|
|
||||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
|
||||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
import dev.langchain4j.provider.OllamaModelFactory;
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
import dev.langchain4j.provider.QianfanModelFactory;
|
|
||||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -70,52 +66,31 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static ArrayList<String> getCandidateValues() {
|
private static ArrayList<String> getCandidateValues() {
|
||||||
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER);
|
||||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER),
|
||||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL));
|
||||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
|
||||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
|
||||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
|
||||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||||
ZhipuModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO,
|
|
||||||
DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO,
|
|
||||||
ZhipuModelFactory.PROVIDER, DEMO));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER),
|
||||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
OpenAiModelFactory.PROVIDER,
|
OpenAiModelFactory.PROVIDER,
|
||||||
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
OllamaModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER,
|
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
||||||
AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
|
||||||
DashscopeModelFactory.PROVIDER,
|
|
||||||
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
|
||||||
QianfanModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
|
||||||
ZhipuModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getModelPathDependency() {
|
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||||
@@ -126,7 +101,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -91,7 +91,8 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ExpressionList<?> leftFunctionParams = leftFunction.getParameters();
|
ExpressionList<?> leftFunctionParams = leftFunction.getParameters();
|
||||||
if (CollectionUtils.isEmpty(leftFunctionParams)) {
|
if (CollectionUtils.isEmpty(leftFunctionParams)
|
||||||
|
|| !(leftFunctionParams.get(0) instanceof Column)) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,10 @@ public class SqlReplaceHelper {
|
|||||||
public static String replaceFields(String sql, Map<String, String> fieldNameMap,
|
public static String replaceFields(String sql, Map<String, String> fieldNameMap,
|
||||||
boolean exactReplace) {
|
boolean exactReplace) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
|
// alias field should not be replaced
|
||||||
|
Set<String> aliases = SqlSelectHelper.getAliasFields(sql);
|
||||||
|
aliases.forEach(alias -> fieldNameMap.put(alias, alias));
|
||||||
|
|
||||||
Set<Select> plainSelectList = SqlSelectHelper.getAllSelect(selectStatement);
|
Set<Select> plainSelectList = SqlSelectHelper.getAllSelect(selectStatement);
|
||||||
for (Select plainSelect : plainSelectList) {
|
for (Select plainSelect : plainSelectList) {
|
||||||
if (plainSelect instanceof PlainSelect) {
|
if (plainSelect instanceof PlainSelect) {
|
||||||
|
|||||||
@@ -989,6 +989,15 @@ public class SqlSelectHelper {
|
|||||||
for (SelectItem selectItem : selectItems) {
|
for (SelectItem selectItem : selectItems) {
|
||||||
selectItem.accept(visitor);
|
selectItem.accept(visitor);
|
||||||
}
|
}
|
||||||
|
if (plainSelect.getHaving() != null) {
|
||||||
|
plainSelect.getHaving().accept(visitor);
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(plainSelect.getOrderByElements())) {
|
||||||
|
for (OrderByElement orderByElement : plainSelect.getOrderByElements()) {
|
||||||
|
orderByElement.getExpression().accept(visitor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return !visitor.getFunctionNames().isEmpty();
|
return !visitor.getFunctionNames().isEmpty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,4 +30,6 @@ public class ChatModelDO {
|
|||||||
private String admin;
|
private String admin;
|
||||||
|
|
||||||
private String viewer;
|
private String viewer;
|
||||||
|
|
||||||
|
private Integer isOpen;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,15 +2,7 @@ package com.tencent.supersonic.common.pojo;
|
|||||||
|
|
||||||
import com.google.common.collect.ImmutableMap;
|
import com.google.common.collect.ImmutableMap;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import dev.langchain4j.provider.AzureModelFactory;
|
import dev.langchain4j.provider.*;
|
||||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
|
||||||
import dev.langchain4j.provider.DifyModelFactory;
|
|
||||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
|
||||||
import dev.langchain4j.provider.ModelProvider;
|
|
||||||
import dev.langchain4j.provider.OllamaModelFactory;
|
|
||||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
|
||||||
import dev.langchain4j.provider.QianfanModelFactory;
|
|
||||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -21,7 +13,7 @@ public class ChatModelParameters {
|
|||||||
|
|
||||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||||
new Parameter("provider", ModelProvider.DEMO_CHAT_MODEL.getProvider(), "接口协议", "",
|
new Parameter("provider", ModelProvider.DEMO_CHAT_MODEL.getProvider(), "接口协议", "",
|
||||||
"list", MODULE_NAME, getCandidateValues());
|
"list", MODULE_NAME, getCandidateProviders());
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||||
new Parameter("baseUrl", ModelProvider.DEMO_CHAT_MODEL.getBaseUrl(), "BaseUrl", "",
|
new Parameter("baseUrl", ModelProvider.DEMO_CHAT_MODEL.getBaseUrl(), "BaseUrl", "",
|
||||||
@@ -37,15 +29,6 @@ public class ChatModelParameters {
|
|||||||
public static final Parameter CHAT_MODEL_API_VERSION = new Parameter("apiVersion", "2024-02-01",
|
public static final Parameter CHAT_MODEL_API_VERSION = new Parameter("apiVersion", "2024-02-01",
|
||||||
"ApiVersion", "", "string", MODULE_NAME, null, getApiVersionDependency());
|
"ApiVersion", "", "string", MODULE_NAME, null, getApiVersionDependency());
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
|
|
||||||
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("secretKey", "demo",
|
|
||||||
"SecretKey", "", "password", MODULE_NAME, null, getSecretKeyDependency());
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = new Parameter("enableSearch", "false",
|
|
||||||
"是否启用搜索增强功能,设为false表示不启用", "", "bool", MODULE_NAME, null, getEnableSearchDependency());
|
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||||
new Parameter("temperature", "0.0", "Temperature", "", "slider", MODULE_NAME);
|
new Parameter("temperature", "0.0", "Temperature", "", "slider", MODULE_NAME);
|
||||||
|
|
||||||
@@ -53,42 +36,27 @@ public class ChatModelParameters {
|
|||||||
new Parameter("timeOut", "60", "超时时间(秒)", "", "number", MODULE_NAME);
|
new Parameter("timeOut", "60", "超时时间(秒)", "", "number", MODULE_NAME);
|
||||||
|
|
||||||
public static List<Parameter> getParameters() {
|
public static List<Parameter> getParameters() {
|
||||||
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
|
||||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION,
|
CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION, CHAT_MODEL_TEMPERATURE,
|
||||||
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
|
CHAT_MODEL_TIMEOUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<String> getCandidateValues() {
|
private static List<String> getCandidateProviders() {
|
||||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
DifyModelFactory.PROVIDER);
|
||||||
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER, DifyModelFactory.PROVIDER);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateProviders(),
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
|
||||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
|
||||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
|
||||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
|
||||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
|
||||||
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_BASE_URL));
|
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_BASE_URL));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER, DifyModelFactory.PROVIDER),
|
||||||
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
|
||||||
DifyModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
|
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
|
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
|
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
|
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
|
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DifyModelFactory.PROVIDER,
|
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DifyModelFactory.PROVIDER,
|
||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||||
}
|
}
|
||||||
@@ -100,33 +68,28 @@ public class ChatModelParameters {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateProviders(),
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
|
||||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
|
||||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
|
||||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
|
||||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME,
|
|
||||||
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_MODEL_NAME));
|
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_MODEL_NAME));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap
|
||||||
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
|
.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "false"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap.of(
|
||||||
QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
OpenAiModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.pojo;
|
|||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
import jakarta.validation.constraints.NotBlank;
|
import jakarta.validation.constraints.NotBlank;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
|
||||||
import static com.tencent.supersonic.common.pojo.Constants.ASC_UPPER;
|
import static com.tencent.supersonic.common.pojo.Constants.ASC_UPPER;
|
||||||
|
|||||||
@@ -22,4 +22,6 @@ public class Text2SQLExemplar implements Serializable {
|
|||||||
private String dbSchema;
|
private String dbSchema;
|
||||||
|
|
||||||
private String sql;
|
private String sql;
|
||||||
|
|
||||||
|
protected double similarity; // 传递相似度,可以作为样本筛选的依据
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import lombok.NoArgsConstructor;
|
|||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.sql.Timestamp;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@@ -22,26 +23,28 @@ public class User implements Serializable {
|
|||||||
|
|
||||||
private Integer isAdmin;
|
private Integer isAdmin;
|
||||||
|
|
||||||
|
private Timestamp lastLogin;
|
||||||
|
|
||||||
public static User get(Long id, String name, String displayName, String email,
|
public static User get(Long id, String name, String displayName, String email,
|
||||||
Integer isAdmin) {
|
Integer isAdmin) {
|
||||||
return new User(id, name, displayName, email, isAdmin);
|
return new User(id, name, displayName, email, isAdmin, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static User get(Long id, String name) {
|
public static User get(Long id, String name) {
|
||||||
return new User(id, name, name, name, 0);
|
return new User(id, name, name, name, 0, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static User getDefaultUser() {
|
public static User getDefaultUser() {
|
||||||
return new User(1L, "admin", "admin", "admin@email", 1);
|
return new User(1L, "admin", "admin", "admin@email", 1, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static User getVisitUser() {
|
public static User getVisitUser() {
|
||||||
return new User(1L, "visit", "visit", "visit@email", 0);
|
return new User(1L, "visit", "visit", "visit@email", 0, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static User getAppUser(int appId) {
|
public static User getAppUser(int appId) {
|
||||||
String name = String.format("app_%s", appId);
|
String name = String.format("app_%s", appId);
|
||||||
return new User(1L, name, name, "", 1);
|
return new User(1L, name, name, "", 1, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getDisplayName() {
|
public String getDisplayName() {
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ public enum EngineType {
|
|||||||
STARROCKS(10, "STARROCKS"),
|
STARROCKS(10, "STARROCKS"),
|
||||||
KYUUBI(11, "KYUUBI"),
|
KYUUBI(11, "KYUUBI"),
|
||||||
PRESTO(12, "PRESTO"),
|
PRESTO(12, "PRESTO"),
|
||||||
TRINO(13, "TRINO"),;
|
TRINO(13, "TRINO"),
|
||||||
|
ORACLE(14, "ORACLE");
|
||||||
|
|
||||||
private Integer code;
|
private Integer code;
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.common.pojo.User;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface ChatModelService {
|
public interface ChatModelService {
|
||||||
List<ChatModel> getChatModels();
|
List<ChatModel> getChatModels(User user);
|
||||||
|
|
||||||
ChatModel getChatModel(Integer id);
|
ChatModel getChatModel(Integer id);
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
@@ -23,8 +24,15 @@ import java.util.stream.Collectors;
|
|||||||
public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModelDO>
|
public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModelDO>
|
||||||
implements ChatModelService {
|
implements ChatModelService {
|
||||||
@Override
|
@Override
|
||||||
public List<ChatModel> getChatModels() {
|
public List<ChatModel> getChatModels(User user) {
|
||||||
return list().stream().map(this::convert).collect(Collectors.toList());
|
return list().stream().map(this::convert).filter(chatModel -> {
|
||||||
|
if (chatModel.isPublic() || user.isSuperAdmin()
|
||||||
|
|| chatModel.getCreatedBy().equals(user.getName())
|
||||||
|
|| chatModel.getViewers().contains(user.getName())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}).sorted(Comparator.comparingLong(ChatModel::getId)).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -41,10 +49,14 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
|||||||
chatModelDO.setCreatedBy(user.getName());
|
chatModelDO.setCreatedBy(user.getName());
|
||||||
chatModelDO.setCreatedAt(new Date());
|
chatModelDO.setCreatedAt(new Date());
|
||||||
chatModelDO.setUpdatedBy(user.getName());
|
chatModelDO.setUpdatedBy(user.getName());
|
||||||
chatModelDO.setUpdatedAt(new Date());
|
chatModelDO.setUpdatedAt(chatModelDO.getCreatedAt());
|
||||||
|
chatModelDO.setIsOpen(chatModel.getIsOpen());
|
||||||
if (StringUtils.isBlank(chatModel.getAdmin())) {
|
if (StringUtils.isBlank(chatModel.getAdmin())) {
|
||||||
chatModelDO.setAdmin(user.getName());
|
chatModelDO.setAdmin(user.getName());
|
||||||
}
|
}
|
||||||
|
if (!chatModel.getViewers().isEmpty()) {
|
||||||
|
chatModelDO.setViewer(JsonUtil.toString(chatModel.getViewers()));
|
||||||
|
}
|
||||||
save(chatModelDO);
|
save(chatModelDO);
|
||||||
chatModel.setId(chatModelDO.getId());
|
chatModel.setId(chatModelDO.getId());
|
||||||
return chatModel;
|
return chatModel;
|
||||||
@@ -55,9 +67,13 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
|||||||
ChatModelDO chatModelDO = convert(chatModel);
|
ChatModelDO chatModelDO = convert(chatModel);
|
||||||
chatModelDO.setUpdatedBy(user.getName());
|
chatModelDO.setUpdatedBy(user.getName());
|
||||||
chatModelDO.setUpdatedAt(new Date());
|
chatModelDO.setUpdatedAt(new Date());
|
||||||
|
chatModelDO.setIsOpen(chatModel.getIsOpen());
|
||||||
if (StringUtils.isBlank(chatModel.getAdmin())) {
|
if (StringUtils.isBlank(chatModel.getAdmin())) {
|
||||||
chatModel.setAdmin(user.getName());
|
chatModel.setAdmin(user.getName());
|
||||||
}
|
}
|
||||||
|
if (!chatModel.getViewers().isEmpty()) {
|
||||||
|
chatModelDO.setViewer(JsonUtil.toString(chatModel.getViewers()));
|
||||||
|
}
|
||||||
updateById(chatModelDO);
|
updateById(chatModelDO);
|
||||||
return chatModel;
|
return chatModel;
|
||||||
}
|
}
|
||||||
@@ -74,6 +90,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
|||||||
ChatModel chatModel = new ChatModel();
|
ChatModel chatModel = new ChatModel();
|
||||||
BeanUtils.copyProperties(chatModelDO, chatModel);
|
BeanUtils.copyProperties(chatModelDO, chatModel);
|
||||||
chatModel.setConfig(JsonUtil.toObject(chatModelDO.getConfig(), ChatModelConfig.class));
|
chatModel.setConfig(JsonUtil.toObject(chatModelDO.getConfig(), ChatModelConfig.class));
|
||||||
|
chatModel.setViewers(JsonUtil.toList(chatModelDO.getViewer(), String.class));
|
||||||
return chatModel;
|
return chatModel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -49,10 +49,10 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
try {
|
try {
|
||||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||||
Embedding embedding = embeddingModel.embed(question).content();
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
MetadataFilterBuilder filterBuilder =
|
||||||
if (existSegment) {
|
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
|
||||||
continue;
|
Filter filter = filterBuilder.isEqualTo(TextSegmentConvert.getQueryId(query));
|
||||||
}
|
embeddingStore.removeAll(filter);
|
||||||
embeddingStore.add(embedding, query);
|
embeddingStore.add(embedding, query);
|
||||||
cache.put(TextSegmentConvert.getQueryId(query), true);
|
cache.put(TextSegmentConvert.getQueryId(query), true);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -62,14 +62,14 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query,
|
private boolean existSegment(String collectionName, EmbeddingStore embeddingStore,
|
||||||
Embedding embedding) {
|
TextSegment query, Embedding embedding) {
|
||||||
String queryId = TextSegmentConvert.getQueryId(query);
|
String queryId = TextSegmentConvert.getQueryId(query);
|
||||||
if (queryId == null) {
|
if (queryId == null) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Check cache first
|
// Check cache first
|
||||||
Boolean cachedResult = cache.getIfPresent(queryId);
|
Boolean cachedResult = cache.getIfPresent(collectionName + queryId);
|
||||||
if (cachedResult != null) {
|
if (cachedResult != null) {
|
||||||
return cachedResult;
|
return cachedResult;
|
||||||
}
|
}
|
||||||
@@ -82,7 +82,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||||
boolean exists = CollectionUtils.isNotEmpty(relevant);
|
boolean exists = CollectionUtils.isNotEmpty(relevant);
|
||||||
cache.put(queryId, exists);
|
cache.put(collectionName + queryId, exists);
|
||||||
return exists;
|
return exists;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,10 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||||
results.forEach(ret -> {
|
results.forEach(ret -> {
|
||||||
ret.getRetrieval().forEach(r -> {
|
ret.getRetrieval().forEach(r -> {
|
||||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
Text2SQLExemplar tmp = // 传递相似度,可以作为样本筛选的依据
|
||||||
|
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
|
||||||
|
tmp.setSimilarity(r.getSimilarity());
|
||||||
|
exemplars.add(tmp);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -242,10 +242,8 @@ public class DateModeUtils {
|
|||||||
return String.format("%s >= '%s' and %s <= '%s'", dateField,
|
return String.format("%s >= '%s' and %s <= '%s'", dateField,
|
||||||
dateInfo.getStartDate(), dateField, dateInfo.getEndDate());
|
dateInfo.getStartDate(), dateField, dateInfo.getEndDate());
|
||||||
}
|
}
|
||||||
LocalDate endData =
|
LocalDate endData = DateUtils.parseDate(dateInfo.getEndDate());
|
||||||
LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT));
|
LocalDate startData = DateUtils.parseDate(dateInfo.getStartDate());
|
||||||
LocalDate startData = LocalDate.parse(dateInfo.getStartDate(),
|
|
||||||
DateTimeFormatter.ofPattern(DAY_FORMAT));
|
|
||||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT);
|
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT);
|
||||||
return String.format("%s >= '%s' and %s <= '%s'", dateField,
|
return String.format("%s >= '%s' and %s <= '%s'", dateField,
|
||||||
startData.format(formatter), dateField, endData.format(formatter));
|
startData.format(formatter), dateField, endData.format(formatter));
|
||||||
@@ -320,7 +318,7 @@ public class DateModeUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public String getDateWhereStr(DateConf dateInfo, ItemDateResp dateDate) {
|
public String getDateWhereStr(DateConf dateInfo, ItemDateResp dateDate) {
|
||||||
if (Objects.isNull(dateInfo)) {
|
if (Objects.isNull(dateInfo) || Objects.isNull(dateInfo.getDateField())) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
String dateStr = "";
|
String dateStr = "";
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ public class DateUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static String getBeforeDate(String currentDate, DatePeriodEnum datePeriodEnum) {
|
public static String getBeforeDate(String currentDate, DatePeriodEnum datePeriodEnum) {
|
||||||
LocalDate specifiedDate = LocalDate.parse(currentDate, DEFAULT_DATE_FORMATTER2);
|
LocalDate specifiedDate = parseDate(currentDate);
|
||||||
LocalDate startDate;
|
LocalDate startDate;
|
||||||
switch (datePeriodEnum) {
|
switch (datePeriodEnum) {
|
||||||
case MONTH:
|
case MONTH:
|
||||||
@@ -93,7 +93,7 @@ public class DateUtils {
|
|||||||
|
|
||||||
public static String getBeforeDate(String currentDate, int intervalDay,
|
public static String getBeforeDate(String currentDate, int intervalDay,
|
||||||
DatePeriodEnum datePeriodEnum) {
|
DatePeriodEnum datePeriodEnum) {
|
||||||
LocalDate specifiedDate = LocalDate.parse(currentDate, DEFAULT_DATE_FORMATTER2);
|
LocalDate specifiedDate = parseDate(currentDate);
|
||||||
LocalDate result = null;
|
LocalDate result = null;
|
||||||
switch (datePeriodEnum) {
|
switch (datePeriodEnum) {
|
||||||
case DAY:
|
case DAY:
|
||||||
@@ -161,11 +161,25 @@ public class DateUtils {
|
|||||||
return !timeString.equals("00:00:00");
|
return !timeString.equals("00:00:00");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static LocalDate parseDate(String timeString) {
|
||||||
|
DateTimeFormatter[] dateFormatters =
|
||||||
|
{DateTimeFormatter.ofPattern("yyyyMMdd"), DateTimeFormatter.ofPattern("yyyy-MM-dd"),
|
||||||
|
DateTimeFormatter.ofPattern("yyyy/MM/dd"),
|
||||||
|
DateTimeFormatter.ofPattern("yyyy-MM")};
|
||||||
|
for (DateTimeFormatter formatter : dateFormatters) {
|
||||||
|
try {
|
||||||
|
return LocalDate.parse(timeString, formatter);
|
||||||
|
} catch (DateTimeParseException ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
public static List<String> getDateList(String startDateStr, String endDateStr,
|
public static List<String> getDateList(String startDateStr, String endDateStr,
|
||||||
DatePeriodEnum period) {
|
DatePeriodEnum period) {
|
||||||
try {
|
try {
|
||||||
LocalDate startDate = LocalDate.parse(startDateStr);
|
LocalDate startDate = parseDate(startDateStr);
|
||||||
LocalDate endDate = LocalDate.parse(endDateStr);
|
LocalDate endDate = parseDate(endDateStr);
|
||||||
List<String> datesInRange = new ArrayList<>();
|
List<String> datesInRange = new ArrayList<>();
|
||||||
LocalDate currentDate = startDate;
|
LocalDate currentDate = startDate;
|
||||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM");
|
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM");
|
||||||
@@ -189,7 +203,7 @@ public class DateUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static boolean isAnyDateString(String value) {
|
public static boolean isAnyDateString(String value) {
|
||||||
List<String> formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd");
|
List<String> formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd", "yyyyMMdd");
|
||||||
return isAnyDateString(value, formats);
|
return isAnyDateString(value, formats);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class ChatModelProperties {
|
|
||||||
|
|
||||||
String baseUrl;
|
|
||||||
String apiKey;
|
|
||||||
String modelName;
|
|
||||||
Double topP;
|
|
||||||
Integer topK;
|
|
||||||
Boolean enableSearch;
|
|
||||||
Integer seed;
|
|
||||||
Float repetitionPenalty;
|
|
||||||
Float temperature;
|
|
||||||
List<String> stops;
|
|
||||||
Integer maxTokens;
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenLanguageModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenStreamingLanguageModel;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.context.annotation.Configuration;
|
|
||||||
|
|
||||||
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
@Configuration
|
|
||||||
@EnableConfigurationProperties(Properties.class)
|
|
||||||
public class DashscopeAutoConfig {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
|
||||||
QwenChatModel qwenChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
|
||||||
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
|
||||||
.topK(chatModelProperties.getTopK())
|
|
||||||
.enableSearch(chatModelProperties.getEnableSearch())
|
|
||||||
.seed(chatModelProperties.getSeed())
|
|
||||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
|
||||||
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
|
||||||
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
|
||||||
.topK(chatModelProperties.getTopK())
|
|
||||||
.enableSearch(chatModelProperties.getEnableSearch())
|
|
||||||
.seed(chatModelProperties.getSeed())
|
|
||||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
|
||||||
QwenLanguageModel qwenLanguageModel(Properties properties) {
|
|
||||||
ChatModelProperties languageModel = properties.getLanguageModel();
|
|
||||||
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
|
||||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
|
||||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
|
||||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
|
||||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
|
||||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
|
||||||
.maxTokens(languageModel.getMaxTokens()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
|
||||||
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
|
|
||||||
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
|
|
||||||
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
|
||||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
|
||||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
|
||||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
|
||||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
|
||||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
|
||||||
.maxTokens(languageModel.getMaxTokens()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
|
||||||
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
|
|
||||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
|
||||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
|
|
||||||
.modelName(embeddingModelProperties.getModelName()).build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class EmbeddingModelProperties {
|
|
||||||
|
|
||||||
private String apiKey;
|
|
||||||
private String modelName;
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
|
||||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
|
||||||
public class Properties {
|
|
||||||
|
|
||||||
static final String PREFIX = "langchain4j.dashscope";
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties chatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingChatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties languageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingLanguageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
EmbeddingModelProperties embeddingModel;
|
|
||||||
}
|
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
package dev.langchain4j.inmemory.spring;
|
package dev.langchain4j.inmemory.spring;
|
||||||
|
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.onnx.bgesmallzh.BgeSmallZhEmbeddingModel;
|
||||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import dev.langchain4j.data.message.ChatMessage;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
private final Double temperature;
|
private final Double temperature;
|
||||||
private final Long timeOut;
|
private final Long timeOut;
|
||||||
|
|
||||||
|
@Setter
|
||||||
private String userName;
|
private String userName;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@@ -54,7 +56,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
@Override
|
@Override
|
||||||
public String generate(String message) {
|
public String generate(String message) {
|
||||||
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
||||||
return difyResult.getAnswer().toString();
|
return difyResult.getAnswer();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -67,7 +69,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
List<ToolSpecification> toolSpecifications) {
|
List<ToolSpecification> toolSpecifications) {
|
||||||
ensureNotEmpty(messages, "messages");
|
ensureNotEmpty(messages, "messages");
|
||||||
DifyResult difyResult =
|
DifyResult difyResult =
|
||||||
this.difyClient.generate(messages.get(0).text(), this.getUserName());
|
this.difyClient.generate(messages.get(0).toString(), this.getUserName());
|
||||||
System.out.println(difyResult.toString());
|
System.out.println(difyResult.toString());
|
||||||
|
|
||||||
if (!isNullOrEmpty(toolSpecifications)) {
|
if (!isNullOrEmpty(toolSpecifications)) {
|
||||||
@@ -84,12 +86,8 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setUserName(String userName) {
|
|
||||||
this.userName = userName;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getUserName() {
|
public String getUserName() {
|
||||||
return null == userName ? "zhaodongsheng" : userName;
|
return null == userName ? "admin" : userName;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
package dev.langchain4j.model.embedding;
|
package dev.langchain4j.model.embedding;
|
||||||
|
|
||||||
|
import dev.langchain4j.model.embedding.onnx.AbstractInProcessEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.onnx.OnnxBertBiEncoder;
|
||||||
|
import dev.langchain4j.model.embedding.onnx.PoolingMode;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
@@ -9,6 +12,7 @@ import java.nio.file.Files;
|
|||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.nio.file.Paths;
|
import java.nio.file.Paths;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
|
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
|
||||||
@@ -25,6 +29,7 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
|||||||
private static volatile String cachedVocabularyPath;
|
private static volatile String cachedVocabularyPath;
|
||||||
|
|
||||||
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
||||||
|
super(Executors.newSingleThreadExecutor());
|
||||||
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||||
synchronized (S2OnnxEmbeddingModel.class) {
|
synchronized (S2OnnxEmbeddingModel.class) {
|
||||||
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||||
@@ -61,8 +66,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
|||||||
|
|
||||||
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||||
try {
|
try {
|
||||||
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile,
|
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel),
|
||||||
PoolingMode.MEAN);
|
vocabularyFile.openStream(), PoolingMode.MEAN);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages
|
|||||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
||||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
||||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
|
||||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||||
import static java.time.Duration.ofSeconds;
|
import static java.time.Duration.ofSeconds;
|
||||||
import static java.util.Collections.emptyList;
|
import static java.util.Collections.emptyList;
|
||||||
@@ -66,7 +66,6 @@ import static java.util.Collections.singletonList;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||||
|
|
||||||
public static final String ZHIPU = "bigmodel";
|
|
||||||
private final OpenAiClient client;
|
private final OpenAiClient client;
|
||||||
private final String baseUrl;
|
private final String baseUrl;
|
||||||
private final String modelName;
|
private final String modelName;
|
||||||
@@ -111,7 +110,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
.connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
.connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
||||||
.logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
.logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
||||||
.customHeaders(customHeaders).build();
|
.customHeaders(customHeaders).build();
|
||||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO.name());
|
||||||
this.apiVersion = apiVersion;
|
this.apiVersion = apiVersion;
|
||||||
this.temperature = getOrDefault(temperature, 0.7);
|
this.temperature = getOrDefault(temperature, 0.7);
|
||||||
this.topP = topP;
|
this.topP = topP;
|
||||||
@@ -130,7 +129,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
this.strictTools = getOrDefault(strictTools, false);
|
this.strictTools = getOrDefault(strictTools, false);
|
||||||
this.parallelToolCalls = parallelToolCalls;
|
this.parallelToolCalls = parallelToolCalls;
|
||||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
this.tokenizer = getOrDefault(tokenizer, () -> new OpenAiTokenizer(this.modelName));
|
||||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,9 +191,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
.responseFormat(responseFormat).seed(seed).user(user)
|
.responseFormat(responseFormat).seed(seed).user(user)
|
||||||
.parallelToolCalls(parallelToolCalls);
|
.parallelToolCalls(parallelToolCalls);
|
||||||
|
|
||||||
if (!(baseUrl.contains(ZHIPU))) {
|
|
||||||
requestBuilder.temperature(temperature);
|
requestBuilder.temperature(temperature);
|
||||||
}
|
|
||||||
|
|
||||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||||
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
package dev.langchain4j.model.zhipu;
|
|
||||||
|
|
||||||
public enum ChatCompletionModel {
|
|
||||||
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
|
|
||||||
|
|
||||||
private final String value;
|
|
||||||
|
|
||||||
ChatCompletionModel(String value) {
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return this.value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
package dev.langchain4j.model.zhipu;
|
|
||||||
|
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
|
|
||||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
|
|
||||||
import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
|
|
||||||
import lombok.Builder;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static dev.langchain4j.internal.RetryUtils.withRetry;
|
|
||||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom;
|
|
||||||
import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
|
|
||||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
|
||||||
import static java.util.Collections.singletonList;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
|
|
||||||
* glm-4. You can find description of parameters
|
|
||||||
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
|
|
||||||
*/
|
|
||||||
public class ZhipuAiChatModel implements ChatLanguageModel {
|
|
||||||
|
|
||||||
private final String baseUrl;
|
|
||||||
private final Double temperature;
|
|
||||||
private final Double topP;
|
|
||||||
private final String model;
|
|
||||||
private final Integer maxRetries;
|
|
||||||
private final Integer maxToken;
|
|
||||||
private final ZhipuAiClient client;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
|
|
||||||
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
|
|
||||||
Boolean logResponses) {
|
|
||||||
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
|
||||||
this.temperature = getOrDefault(temperature, 0.7);
|
|
||||||
this.topP = topP;
|
|
||||||
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
|
|
||||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
|
||||||
this.maxToken = getOrDefault(maxToken, 512);
|
|
||||||
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
|
|
||||||
.logRequests(getOrDefault(logRequests, false))
|
|
||||||
.logResponses(getOrDefault(logResponses, false)).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ZhipuAiChatModelBuilder builder() {
|
|
||||||
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
|
|
||||||
ZhipuAiChatModelBuilderFactory.class)) {
|
|
||||||
return factories.get();
|
|
||||||
}
|
|
||||||
return new ZhipuAiChatModelBuilder();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
|
||||||
return generate(messages, (ToolSpecification) null);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
|
||||||
List<ToolSpecification> toolSpecifications) {
|
|
||||||
ensureNotEmpty(messages, "messages");
|
|
||||||
|
|
||||||
ChatCompletionRequest.Builder requestBuilder =
|
|
||||||
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
|
|
||||||
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
|
|
||||||
|
|
||||||
if (!isNullOrEmpty(toolSpecifications)) {
|
|
||||||
requestBuilder.tools(toTools(toolSpecifications));
|
|
||||||
}
|
|
||||||
|
|
||||||
ChatCompletionResponse response =
|
|
||||||
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
|
|
||||||
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
|
|
||||||
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
|
||||||
ToolSpecification toolSpecification) {
|
|
||||||
return generate(messages,
|
|
||||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class ZhipuAiChatModelBuilder {
|
|
||||||
public ZhipuAiChatModelBuilder() {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
|
||||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.time.Duration;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
public static final String PROVIDER = "AZURE";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
|
||||||
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
|
|
||||||
.deploymentName(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
|
|
||||||
.topP(modelConfig.getTopP())
|
|
||||||
.timeout(Duration.ofSeconds(
|
|
||||||
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
|
||||||
.logRequestsAndResponses(
|
|
||||||
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
AzureOpenAiEmbeddingModel.Builder builder =
|
|
||||||
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelConfig.getApiKey())
|
|
||||||
.deploymentName(embeddingModelConfig.getModelName())
|
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
|
||||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
|
||||||
&& embeddingModelConfig.getLogResponses());
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
public static final String PROVIDER = "DASHSCOPE";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
|
||||||
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature() == null ? 0L
|
|
||||||
: modelConfig.getTemperature().floatValue())
|
|
||||||
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
|
|
||||||
.modelName(embeddingModelConfig.getModelName()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.dify.DifyAiChatModel;
|
import dev.langchain4j.model.dify.DifyAiChatModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
import org.springframework.beans.factory.InitializingBean;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -27,8 +27,9 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
.apiKey(embeddingModelConfig.getApiKey())
|
||||||
|
.modelName(embeddingModelConfig.getModelName())
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||||
.logRequests(embeddingModelConfig.getLogRequests())
|
.logRequests(embeddingModelConfig.getLogRequests())
|
||||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package dev.langchain4j.provider;
|
package dev.langchain4j.provider;
|
||||||
|
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.onnx.bgesmallzh.BgeSmallZhEmbeddingModel;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
|
|
||||||
public static final String PROVIDER = "QIANFAN";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";
|
|
||||||
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
|
|
||||||
public static final String DEFAULT_ENDPOINT = "llama_2_70b";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
|
||||||
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
|
|
||||||
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
|
||||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
|
||||||
.logResponses(modelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelConfig.getApiKey())
|
|
||||||
.secretKey(embeddingModelConfig.getSecretKey())
|
|
||||||
.modelName(embeddingModelConfig.getModelName())
|
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
|
||||||
.logRequests(embeddingModelConfig.getLogRequests())
|
|
||||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ChatCompletionModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import static java.time.Duration.ofSeconds;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
public static final String PROVIDER = "ZHIPU";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
|
||||||
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
|
||||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
|
||||||
.logResponses(modelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
|
|
||||||
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
|
|
||||||
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
|
|
||||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class ChatModelProperties {
|
|
||||||
private String baseUrl;
|
|
||||||
private String apiKey;
|
|
||||||
private String secretKey;
|
|
||||||
private Double temperature;
|
|
||||||
private Integer maxRetries;
|
|
||||||
private Double topP;
|
|
||||||
private String modelName;
|
|
||||||
private String endpoint;
|
|
||||||
private String responseFormat;
|
|
||||||
private Double penaltyScore;
|
|
||||||
private Boolean logRequests;
|
|
||||||
private Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class EmbeddingModelProperties {
|
|
||||||
private String baseUrl;
|
|
||||||
private String apiKey;
|
|
||||||
private String secretKey;
|
|
||||||
private Integer maxRetries;
|
|
||||||
private String modelName;
|
|
||||||
private String endpoint;
|
|
||||||
private String user;
|
|
||||||
private Boolean logRequests;
|
|
||||||
private Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class LanguageModelProperties {
|
|
||||||
private String baseUrl;
|
|
||||||
private String apiKey;
|
|
||||||
private String secretKey;
|
|
||||||
private Double temperature;
|
|
||||||
private Integer maxRetries;
|
|
||||||
private Integer topK;
|
|
||||||
private Double topP;
|
|
||||||
private String modelName;
|
|
||||||
private String endpoint;
|
|
||||||
private Double penaltyScore;
|
|
||||||
private Boolean logRequests;
|
|
||||||
private Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
|
||||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
|
||||||
public class Properties {
|
|
||||||
|
|
||||||
static final String PREFIX = "langchain4j.qianfan";
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties chatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingChatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
LanguageModelProperties languageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
LanguageModelProperties streamingLanguageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
EmbeddingModelProperties embeddingModel;
|
|
||||||
}
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanLanguageModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanStreamingChatModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanStreamingLanguageModel;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.context.annotation.Configuration;
|
|
||||||
|
|
||||||
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
@Configuration
|
|
||||||
@EnableConfigurationProperties(Properties.class)
|
|
||||||
public class QianfanAutoConfig {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
|
||||||
QianfanChatModel qianfanChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
|
||||||
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.secretKey(chatModelProperties.getSecretKey())
|
|
||||||
.endpoint(chatModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
|
||||||
.modelName(chatModelProperties.getModelName())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP())
|
|
||||||
.responseFormat(chatModelProperties.getResponseFormat())
|
|
||||||
.maxRetries(chatModelProperties.getMaxRetries())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
|
||||||
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
|
||||||
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.secretKey(chatModelProperties.getSecretKey())
|
|
||||||
.modelName(chatModelProperties.getModelName())
|
|
||||||
.responseFormat(chatModelProperties.getResponseFormat())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
|
||||||
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
|
|
||||||
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
|
|
||||||
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
|
||||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
|
||||||
.baseUrl(languageModelProperties.getBaseUrl())
|
|
||||||
.apiKey(languageModelProperties.getApiKey())
|
|
||||||
.secretKey(languageModelProperties.getSecretKey())
|
|
||||||
.modelName(languageModelProperties.getModelName())
|
|
||||||
.temperature(languageModelProperties.getTemperature())
|
|
||||||
.maxRetries(languageModelProperties.getMaxRetries())
|
|
||||||
.logRequests(languageModelProperties.getLogRequests())
|
|
||||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
|
||||||
QianfanStreamingLanguageModel qianfanStreamingLanguageModel(Properties properties) {
|
|
||||||
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
|
|
||||||
return QianfanStreamingLanguageModel.builder()
|
|
||||||
.endpoint(languageModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
|
||||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
|
||||||
.baseUrl(languageModelProperties.getBaseUrl())
|
|
||||||
.apiKey(languageModelProperties.getApiKey())
|
|
||||||
.secretKey(languageModelProperties.getSecretKey())
|
|
||||||
.modelName(languageModelProperties.getModelName())
|
|
||||||
.temperature(languageModelProperties.getTemperature())
|
|
||||||
.maxRetries(languageModelProperties.getMaxRetries())
|
|
||||||
.logRequests(languageModelProperties.getLogRequests())
|
|
||||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
|
||||||
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
|
|
||||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
|
||||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
|
||||||
.endpoint(embeddingModelProperties.getEndpoint())
|
|
||||||
.apiKey(embeddingModelProperties.getApiKey())
|
|
||||||
.secretKey(embeddingModelProperties.getSecretKey())
|
|
||||||
.modelName(embeddingModelProperties.getModelName())
|
|
||||||
.user(embeddingModelProperties.getUser())
|
|
||||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
|
||||||
.logRequests(embeddingModelProperties.getLogRequests())
|
|
||||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -57,6 +57,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||||||
private final ConsistencyLevelEnum consistencyLevel;
|
private final ConsistencyLevelEnum consistencyLevel;
|
||||||
private final boolean retrieveEmbeddingsOnSearch;
|
private final boolean retrieveEmbeddingsOnSearch;
|
||||||
private final boolean autoFlushOnInsert;
|
private final boolean autoFlushOnInsert;
|
||||||
|
private final FieldDefinition fieldDefinition;
|
||||||
|
|
||||||
public MilvusEmbeddingStore(String host, Integer port, String collectionName, Integer dimension,
|
public MilvusEmbeddingStore(String host, Integer port, String collectionName, Integer dimension,
|
||||||
IndexType indexType, MetricType metricType, String uri, String token, String username,
|
IndexType indexType, MetricType metricType, String uri, String token, String username,
|
||||||
@@ -78,11 +79,15 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||||||
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
|
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
|
||||||
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
|
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
|
||||||
|
|
||||||
|
// Define the field structure for the collection
|
||||||
|
this.fieldDefinition = new FieldDefinition(ID_FIELD_NAME, TEXT_FIELD_NAME,
|
||||||
|
METADATA_FIELD_NAME, VECTOR_FIELD_NAME);
|
||||||
|
|
||||||
if (!hasCollection(this.milvusClient, this.collectionName)) {
|
if (!hasCollection(this.milvusClient, this.collectionName)) {
|
||||||
createCollection(this.milvusClient, this.collectionName,
|
createCollection(this.milvusClient, this.collectionName, fieldDefinition,
|
||||||
ensureNotNull(dimension, "dimension"));
|
ensureNotNull(dimension, "dimension"));
|
||||||
createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT),
|
createIndex(this.milvusClient, this.collectionName, VECTOR_FIELD_NAME,
|
||||||
this.metricType);
|
getOrDefault(indexType, FLAT), this.metricType);
|
||||||
}
|
}
|
||||||
|
|
||||||
loadCollectionInMemory(this.milvusClient, collectionName);
|
loadCollectionInMemory(this.milvusClient, collectionName);
|
||||||
@@ -128,7 +133,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||||||
public EmbeddingSearchResult<TextSegment> search(
|
public EmbeddingSearchResult<TextSegment> search(
|
||||||
EmbeddingSearchRequest embeddingSearchRequest) {
|
EmbeddingSearchRequest embeddingSearchRequest) {
|
||||||
|
|
||||||
SearchParam searchParam = buildSearchRequest(collectionName,
|
SearchParam searchParam = buildSearchRequest(collectionName, fieldDefinition,
|
||||||
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
||||||
embeddingSearchRequest.filter(), embeddingSearchRequest.maxResults(), metricType,
|
embeddingSearchRequest.filter(), embeddingSearchRequest.maxResults(), metricType,
|
||||||
consistencyLevel);
|
consistencyLevel);
|
||||||
@@ -137,7 +142,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||||||
CollectionOperationsExecutor.search(milvusClient, searchParam);
|
CollectionOperationsExecutor.search(milvusClient, searchParam);
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(milvusClient, resultsWrapper,
|
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(milvusClient, resultsWrapper,
|
||||||
collectionName, consistencyLevel, retrieveEmbeddingsOnSearch);
|
collectionName, fieldDefinition, consistencyLevel, retrieveEmbeddingsOnSearch);
|
||||||
|
|
||||||
List<EmbeddingMatch<TextSegment>> result =
|
List<EmbeddingMatch<TextSegment>> result =
|
||||||
matches.stream().filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
matches.stream().filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
||||||
@@ -226,7 +231,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|||||||
@Override
|
@Override
|
||||||
public void removeAll(Filter filter) {
|
public void removeAll(Filter filter) {
|
||||||
ensureNotNull(filter, "filter");
|
ensureNotNull(filter, "filter");
|
||||||
removeForVector(this.milvusClient, this.collectionName, map(filter));
|
removeForVector(this.milvusClient, this.collectionName, map(filter, METADATA_FIELD_NAME));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class ChatModelProperties {
|
|
||||||
|
|
||||||
String baseUrl;
|
|
||||||
String apiKey;
|
|
||||||
Double temperature;
|
|
||||||
Double topP;
|
|
||||||
String modelName;
|
|
||||||
Integer maxRetries;
|
|
||||||
Integer maxToken;
|
|
||||||
Boolean logRequests;
|
|
||||||
Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class EmbeddingModelProperties {
|
|
||||||
|
|
||||||
String baseUrl;
|
|
||||||
String apiKey;
|
|
||||||
String model;
|
|
||||||
Integer maxRetries;
|
|
||||||
Boolean logRequests;
|
|
||||||
Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
|
||||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
|
||||||
public class Properties {
|
|
||||||
|
|
||||||
static final String PREFIX = "langchain4j.zhipu";
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties chatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingChatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
EmbeddingModelProperties embeddingModel;
|
|
||||||
}
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiStreamingChatModel;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.context.annotation.Configuration;
|
|
||||||
|
|
||||||
import static dev.langchain4j.zhipu.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
@Configuration
|
|
||||||
@EnableConfigurationProperties(Properties.class)
|
|
||||||
public class ZhipuAutoConfig {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
|
||||||
ZhipuAiChatModel zhipuAiChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
|
||||||
return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
|
|
||||||
.maxToken(chatModelProperties.getMaxToken())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
|
||||||
ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
|
||||||
return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
|
||||||
ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) {
|
|
||||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
|
||||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelProperties.getApiKey())
|
|
||||||
.model(embeddingModelProperties.getModel())
|
|
||||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
|
||||||
.logRequests(embeddingModelProperties.getLogRequests())
|
|
||||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.common.calcite;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
import com.tencent.supersonic.common.pojo.enums.EngineType;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.calcite.sql.parser.SqlParseException;
|
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -12,7 +11,7 @@ import java.util.Collections;
|
|||||||
class SqlWithMergerTest {
|
class SqlWithMergerTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void test1() throws SqlParseException {
|
void test1() throws Exception {
|
||||||
String sql1 = "WITH DepartmentVisits AS (\n" + " SELECT department, SUM(pv) AS 总访问次数\n"
|
String sql1 = "WITH DepartmentVisits AS (\n" + " SELECT department, SUM(pv) AS 总访问次数\n"
|
||||||
+ " FROM t_1\n"
|
+ " FROM t_1\n"
|
||||||
+ " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n"
|
+ " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n"
|
||||||
@@ -38,7 +37,7 @@ class SqlWithMergerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void test2() throws SqlParseException {
|
void test2() throws Exception {
|
||||||
|
|
||||||
String sql1 =
|
String sql1 =
|
||||||
"WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' "
|
"WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' "
|
||||||
@@ -65,7 +64,7 @@ class SqlWithMergerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void test3() throws SqlParseException {
|
void test3() throws Exception {
|
||||||
|
|
||||||
String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000";
|
String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000";
|
||||||
|
|
||||||
@@ -89,7 +88,7 @@ class SqlWithMergerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void test4() throws SqlParseException {
|
void test4() throws Exception {
|
||||||
String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100";
|
String sql1 = "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100";
|
||||||
|
|
||||||
String sql2 =
|
String sql2 =
|
||||||
@@ -112,7 +111,7 @@ class SqlWithMergerTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void test5() throws SqlParseException {
|
void test5() throws Exception {
|
||||||
|
|
||||||
String sql1 = "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100";
|
String sql1 = "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100";
|
||||||
|
|
||||||
@@ -132,13 +131,13 @@ class SqlWithMergerTest {
|
|||||||
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
|
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` "
|
||||||
+ "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN "
|
+ "FROM (SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN "
|
||||||
+ "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) "
|
+ "(SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) "
|
||||||
+ "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM Department INNER JOIN Visits "
|
+ "AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) SELECT COUNT(*) FROM Department JOIN Visits "
|
||||||
+ "WHERE 总访问次数 > 100");
|
+ "WHERE 总访问次数 > 100");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void test6() throws SqlParseException {
|
void test6() throws Exception {
|
||||||
|
|
||||||
String sql1 =
|
String sql1 =
|
||||||
"SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10";
|
"SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10";
|
||||||
@@ -159,7 +158,36 @@ class SqlWithMergerTest {
|
|||||||
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM "
|
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM "
|
||||||
+ "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`,"
|
+ "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`,"
|
||||||
+ " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) "
|
+ " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) "
|
||||||
+ "SELECT COUNT(*) FROM Department INNER JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10");
|
+ "SELECT COUNT(*) FROM Department JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void test7() throws Exception {
|
||||||
|
|
||||||
|
String sql1 =
|
||||||
|
"SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 AND imp_date >= CURRENT_DATE - "
|
||||||
|
+ "INTERVAL '1 year' AND sys_imp_date < CURRENT_DATE ORDER"
|
||||||
|
+ " BY 总访问次数 LIMIT 10";
|
||||||
|
|
||||||
|
String sql2 =
|
||||||
|
"SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
|
||||||
|
+ "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
|
||||||
|
+ "`s2_user_department`) AS `t2`\n"
|
||||||
|
+ "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
|
||||||
|
+ "FROM\n"
|
||||||
|
+ "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";
|
||||||
|
|
||||||
|
String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
|
||||||
|
Collections.singletonList(sql2), Collections.singletonList("t_1"));
|
||||||
|
|
||||||
|
|
||||||
|
Assert.assertEquals(format(mergeSql),
|
||||||
|
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM "
|
||||||
|
+ "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`,"
|
||||||
|
+ " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) "
|
||||||
|
+ "SELECT COUNT(*) FROM Department JOIN Visits WHERE 总访问次数 > 100 AND imp_date >= "
|
||||||
|
+ "CURRENT_DATE - INTERVAL '1 year' AND sys_imp_date < CURRENT_DATE ORDER BY 总访问次数 "
|
||||||
|
+ "LIMIT 10");
|
||||||
}
|
}
|
||||||
|
|
||||||
private static String format(String mergeSql) {
|
private static String format(String mergeSql) {
|
||||||
|
|||||||
@@ -1,12 +1,8 @@
|
|||||||
FROM supersonicbi/supersonic:0.9.10-SNAPSHOT
|
FROM openjdk:21-jdk-bullseye as base
|
||||||
|
|
||||||
# Set the working directory in the container
|
# Set the working directory in the container
|
||||||
WORKDIR /usr/src/app
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
# Delete old supersonic installation directory and the symbolic link
|
|
||||||
RUN rm -rf /usr/src/app/supersonic-standalone-0.9.10-SNAPSHOT
|
|
||||||
RUN rm -f /usr/src/app/supersonic-standalone-latest
|
|
||||||
|
|
||||||
# Argument to pass in the supersonic version at build time
|
# Argument to pass in the supersonic version at build time
|
||||||
ARG SUPERSONIC_VERSION
|
ARG SUPERSONIC_VERSION
|
||||||
|
|
||||||
@@ -17,6 +13,17 @@ COPY assembly/build/supersonic-standalone-${SUPERSONIC_VERSION}.zip .
|
|||||||
RUN unzip supersonic-standalone-${SUPERSONIC_VERSION}.zip && \
|
RUN unzip supersonic-standalone-${SUPERSONIC_VERSION}.zip && \
|
||||||
rm supersonic-standalone-${SUPERSONIC_VERSION}.zip
|
rm supersonic-standalone-${SUPERSONIC_VERSION}.zip
|
||||||
|
|
||||||
|
FROM openjdk:21-slim
|
||||||
|
|
||||||
|
# Set the working directory in the container
|
||||||
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
|
# Argument to pass in the supersonic version at build time
|
||||||
|
ARG SUPERSONIC_VERSION
|
||||||
|
|
||||||
|
# Copy the supersonic standalone folder into the container
|
||||||
|
COPY --from=base /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION} ./supersonic-standalone-${SUPERSONIC_VERSION}
|
||||||
|
|
||||||
# Create a symbolic link to the supersonic installation directory
|
# Create a symbolic link to the supersonic installation directory
|
||||||
RUN ln -s /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION} /usr/src/app/supersonic-standalone-latest
|
RUN ln -s /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION} /usr/src/app/supersonic-standalone-latest
|
||||||
|
|
||||||
@@ -27,4 +34,4 @@ WORKDIR /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION}
|
|||||||
EXPOSE 9080
|
EXPOSE 9080
|
||||||
# Command to run the supersonic daemon
|
# Command to run the supersonic daemon
|
||||||
RUN chmod +x bin/supersonic-daemon.sh
|
RUN chmod +x bin/supersonic-daemon.sh
|
||||||
CMD ["bash", "-c", "bin/supersonic-daemon.sh restart standalone docker && tail -f /dev/null"]
|
CMD ["bash", "-c", "bin/supersonic-daemon.sh restart standalone ${S2_DB_TYPE} && tail -f /dev/null"]
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
# Function to execute the build script
|
# Function to execute the build script
|
||||||
execute_build_script() {
|
execute_build_script() {
|
||||||
echo "Executing build script: assembly/bin/supersonic-build.sh"
|
echo "Executing build script: sh assembly/bin/supersonic-build.sh"
|
||||||
assembly/bin/supersonic-build.sh
|
sh assembly/bin/supersonic-build.sh
|
||||||
}
|
}
|
||||||
|
|
||||||
# Function to build the Docker image
|
# Function to build the Docker image
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user