mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-20 13:44:19 +08:00
Compare commits
3 Commits
master
...
ecea348c44
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ecea348c44 | ||
|
|
b1dadb4a1a | ||
|
|
158a0a802a |
9
.github/workflows/centos-ci.yml
vendored
9
.github/workflows/centos-ci.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: almalinux:9 # maven >=3.6.3
|
||||
image: quay.io/centos/centos:stream8 # 使用 CentOS Stream 8 容器
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -28,10 +28,9 @@ jobs:
|
||||
|
||||
- name: Reset DNF repositories
|
||||
run: |
|
||||
sed -e 's|^mirrorlist=|#mirrorlist=|g' \
|
||||
-e 's|^# baseurl=https://repo.almalinux.org|baseurl=https://mirrors.aliyun.com|g' \
|
||||
/etc/yum.repos.d/almalinux*.repo
|
||||
|
||||
cd /etc/yum.repos.d/
|
||||
sed -i 's/mirrorlist/#mirrorlist/g' /etc/yum.repos.d/CentOS-*
|
||||
sed -i 's|#baseurl=http://mirror.centos.org|baseurl=http://vault.centos.org|g' /etc/yum.repos.d/CentOS-*
|
||||
|
||||
- name: Update DNF package index
|
||||
run: dnf makecache
|
||||
|
||||
14
.github/workflows/mac-ci.yml
vendored
14
.github/workflows/mac-ci.yml
vendored
@@ -17,27 +17,21 @@ jobs:
|
||||
java-version: [21] # Define the JDK versions to test
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up JDK ${{ matrix.java-version }}
|
||||
uses: actions/setup-java@v3
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: ${{ matrix.java-version }}
|
||||
distribution: 'temurin'
|
||||
distribution: 'adopt'
|
||||
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/Library/Caches/Maven # macOS Maven cache path
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
brew update
|
||||
brew install cmake
|
||||
brew install gcc
|
||||
|
||||
- name: Build with Maven
|
||||
run: mvn -B package --file pom.xml
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -20,5 +20,4 @@ chm_db/
|
||||
__pycache__/
|
||||
/dict
|
||||
assembly/build/*-SNAPSHOT
|
||||
**/node_modules/
|
||||
benchmark/res/
|
||||
**/node_modules/
|
||||
199
CHANGELOG.md
199
CHANGELOG.md
@@ -3,205 +3,6 @@
|
||||
- All notable changes to this project will be documented in this file.
|
||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||
compatibility issues with previous versions.
|
||||
## SuperSonic [1.0.0] - 2025-08-05
|
||||
|
||||
### 重大特性变更 / Major Features
|
||||
|
||||
#### 多数据库支持扩展 / Multi-Database Support
|
||||
- **Oracle数据库支持**: 新增Oracle数据库引擎类型及适配器 ([8eeed87ba](https://github.com/tencentmusic/supersonic/commit/8eeed87ba) by supersonicbi)
|
||||
- **StarRocks支持**: 支持StarRocks和多catalog功能 ([33268bf3d](https://github.com/tencentmusic/supersonic/commit/33268bf3d) by zyclove)
|
||||
- **SAP HANA支持**: 新增SAP HANA数据库适配支持 ([2e28a4c7a](https://github.com/tencentmusic/supersonic/commit/2e28a4c7a) by wwsheng009)
|
||||
- **DuckDB支持**: 支持DuckDB数据库 ([a058dc8b6](https://github.com/tencentmusic/supersonic/commit/a058dc8b6) by jerryjzhang)
|
||||
- **Kyuubi支持**: 支持Kyuubi Presto Trino ([5e3bafb95](https://github.com/tencentmusic/supersonic/commit/5e3bafb95) by zyclove)
|
||||
- **OpenSearch支持**: 新增OpenSearch支持 ([d942d35c9](https://github.com/tencentmusic/supersonic/commit/d942d35c9) by zyclove)
|
||||
|
||||
#### 智能问答增强 / AI-Enhanced Query Processing
|
||||
- **LLM纠错器**: 新增LLM物理SQL纠错器 ([f899d23b6](https://github.com/tencentmusic/supersonic/commit/f899d23b6) by 柯慕灵)
|
||||
- **记忆管理**: Agent记忆管理启用few-shot优先机制 ([fae9118c2](https://github.com/tencentmusic/supersonic/commit/fae9118c2) by feelshana)
|
||||
- **结构化查询**: 支持struct查询中的offset子句 ([d2a43a99c](https://github.com/tencentmusic/supersonic/commit/d2a43a99c) by jerryjzhang)
|
||||
- **向量召回优化**: 优化嵌入向量召回机制 ([8c6ae6252](https://github.com/tencentmusic/supersonic/commit/8c6ae6252) by lexluo09)
|
||||
|
||||
#### 权限管理系统 / Permission Management
|
||||
- **Agent权限**: 支持agent级别的权限管理 ([b5aa6e046](https://github.com/tencentmusic/supersonic/commit/b5aa6e046) by jerryjzhang)
|
||||
- **用户管理**: 支持用户删除功能 ([1c9cf788c](https://github.com/tencentmusic/supersonic/commit/1c9cf788c) by supersonicbi)
|
||||
- **鉴权优化**: 全面优化鉴权与召回机制 ([1faf84e37](https://github.com/tencentmusic/supersonic/commit/1faf84e37), [7e6639df8](https://github.com/tencentmusic/supersonic/commit/7e6639df8) by guilinlewis)
|
||||
|
||||
### 架构升级 / Architecture Upgrades
|
||||
|
||||
#### 核心框架升级 / Core Framework Upgrades
|
||||
- **SpringBoot 3升级**: 完成SpringBoot 3.x升级 ([07f6be51c](https://github.com/tencentmusic/supersonic/commit/07f6be51c) by mislayming)
|
||||
- **依赖升级**: 升级依赖包并修复安全漏洞 ([232a20227](https://github.com/tencentmusic/supersonic/commit/232a20227) by beat4ocean)
|
||||
- **LangChain4j更新**: 替换已废弃的LangChain4j APIs ([acffc03c7](https://github.com/tencentmusic/supersonic/commit/acffc03c7) by beat4ocean)
|
||||
- **Swagger升级**: 使用SpringDoc支持Swagger在Spring 3.x ([758d170bb](https://github.com/tencentmusic/supersonic/commit/758d170bb) by jerryjzhang)
|
||||
|
||||
#### 许可证变更 / License Changes
|
||||
- **Apache 2.0**: 从MIT更改为Apache 2.0许可证 ([0aa002882](https://github.com/tencentmusic/supersonic/commit/0aa002882) by jerryjzhang)
|
||||
|
||||
### 性能优化 / Performance Improvements
|
||||
|
||||
#### 系统性能 / System Performance
|
||||
- **GC优化**: 实现Generational ZGC ([3fc1ec42b](https://github.com/tencentmusic/supersonic/commit/3fc1ec42b) by beat4ocean)
|
||||
- **Docker优化**: 减少Docker镜像体积 ([614917ba7](https://github.com/tencentmusic/supersonic/commit/614917ba7) by kino)
|
||||
- **并行处理**: 嵌入向量并行执行优化 ([8c6ae6252](https://github.com/tencentmusic/supersonic/commit/8c6ae6252) by lexluo09)
|
||||
- **记忆评估**: 记忆评估性能优化 ([524ec38ed](https://github.com/tencentmusic/supersonic/commit/524ec38ed) by yudong)
|
||||
- **多平台构建**: 支持Docker多平台构建 ([da6d28c18](https://github.com/tencentmusic/supersonic/commit/da6d28c18) by jerryjzhang)
|
||||
|
||||
#### 数据处理优化 / Data Processing Optimization
|
||||
- **日期格式**: 支持更多日期字符串格式 ([2b13866c0](https://github.com/tencentmusic/supersonic/commit/2b13866c0) by supersonicbi)
|
||||
- **SQL优化**: 优化SQL生成和执行性能 ([0ab764329](https://github.com/tencentmusic/supersonic/commit/0ab764329) by jerryjzhang)
|
||||
- **模型关联**: 优化模型关联查询性能 ([47c2595fb](https://github.com/tencentmusic/supersonic/commit/47c2595fb) by Willy-J)
|
||||
|
||||
### 功能增强 / Feature Enhancements
|
||||
|
||||
#### 前端界面优化 / Frontend Improvements
|
||||
- **图表导出**: 消息支持导出图表图片 ([ce9ae1c0c](https://github.com/tencentmusic/supersonic/commit/ce9ae1c0c) by pisces)
|
||||
- **路由重构**: 重构语义建模路由交互 ([82c63a7f2](https://github.com/tencentmusic/supersonic/commit/82c63a7f2) by tristanliu)
|
||||
- **权限界面**: 统一助理权限设置交互界面 ([46d64d78f](https://github.com/tencentmusic/supersonic/commit/46d64d78f) by tristanliu)
|
||||
- **图表优化**: 优化ChatMsg图表条件 ([06fb6ba74](https://github.com/tencentmusic/supersonic/commit/06fb6ba74) by FredTsang)
|
||||
- **数据格式**: 提取formatByDataFormatType()方法 ([9ffdba956](https://github.com/tencentmusic/supersonic/commit/9ffdba956) by FredTsang)
|
||||
|
||||
#### 开发体验 / Developer Experience
|
||||
- **构建脚本**: 优化Web应用构建脚本 ([baae7f74b](https://github.com/tencentmusic/supersonic/commit/baae7f74b) by zyclove)
|
||||
- **GitHub Actions**: 优化GitHub Actions镜像推送 ([6a4458a57](https://github.com/tencentmusic/supersonic/commit/6a4458a57) by lexluo09)
|
||||
- **基准测试**: 改进基准测试,增加解析结果分析 ([97710a90c](https://github.com/tencentmusic/supersonic/commit/97710a90c) by Antgeek)
|
||||
|
||||
### Bug修复 / Bug Fixes
|
||||
|
||||
#### 核心功能修复 / Core Function Fixes
|
||||
- **插件功能**: 修复插件功能无法调用/结果被NL2SQL覆盖问题 ([c75233e37](https://github.com/tencentmusic/supersonic/commit/c75233e37) by QJ_wonder)
|
||||
- **维度别名**: 修复映射阶段维度值别名不生效问题 ([785bda6cd](https://github.com/tencentmusic/supersonic/commit/785bda6cd) by feelshana)
|
||||
- **模型字段**: 修复模型字段更新问题 ([6bd897084](https://github.com/tencentmusic/supersonic/commit/6bd897084) by WDEP)
|
||||
- **多轮对话**: 修复headless中字段查询及多轮对话使用问题 ([be0447ae1](https://github.com/tencentmusic/supersonic/commit/be0447ae1) by QJ_wonder)
|
||||
|
||||
#### NPE异常修复 / NPE Exception Fixes
|
||||
- **聊天查询**: 修复EmbeddingMatchStrategy.detectByBatch() NPE异常 ([6d907b6ad](https://github.com/tencentmusic/supersonic/commit/6d907b6ad) by wangyong)
|
||||
- **文件处理**: 修复FileHandlerImpl.convert2Resp() 维度值数据行首字符为空格异常 ([da172a030](https://github.com/tencentmusic/supersonic/commit/da172a030) by wangyong)
|
||||
- **头部服务**: 修复多处headless NPE问题 ([79a44b27e](https://github.com/tencentmusic/supersonic/commit/79a44b27e) by jerryjzhang)
|
||||
- **解析信息**: 修复getParseInfo中的NPE ([dce9a8a58](https://github.com/tencentmusic/supersonic/commit/dce9a8a58) by supersonicbi)
|
||||
|
||||
#### SQL兼容性修复 / SQL Compatibility Fixes
|
||||
- **SQL处理**: 修复SQL前后换行符导致的语句结尾";"删除问题 ([55ac3d1aa](https://github.com/tencentmusic/supersonic/commit/55ac3d1aa) by wangyong)
|
||||
- **查询别名**: DictUtils.constructQuerySqlReq针对sql query增加别名 ([042791762](https://github.com/tencentmusic/supersonic/commit/042791762) by andybj0228)
|
||||
- **SQL变量**: 支持SQL脚本变量替换 ([0709575cd](https://github.com/tencentmusic/supersonic/commit/0709575cd) by wanglongqiang)
|
||||
|
||||
#### 前端Bug修复 / Frontend Bug Fixes
|
||||
- **UI样式**: 修复问答对话右侧历史对话模块样式异常 ([c33a85b58](https://github.com/tencentmusic/supersonic/commit/c33a85b58) by wangyong)
|
||||
- **推荐维度**: 修复页面不显示推荐下钻维度问题 ([62b9db679](https://github.com/tencentmusic/supersonic/commit/62b9db679) by WDEP)
|
||||
- **图表显示**: 修复饼图显示条件问题 ([1b8cd7f0d](https://github.com/tencentmusic/supersonic/commit/1b8cd7f0d) by WDEP)
|
||||
- **负数支持**: 支持负数显示 ([2552e2ae4](https://github.com/tencentmusic/supersonic/commit/2552e2ae4) by FredTsang)
|
||||
- **百分比显示**: 支持bar图needMultiply100显示正确百分比值 ([8abfc923a](https://github.com/tencentmusic/supersonic/commit/8abfc923a) by coosir)
|
||||
- **TypeScript错误**: 修复前端TypeScript错误 ([5585b9e22](https://github.com/tencentmusic/supersonic/commit/5585b9e22) by poncheen)
|
||||
|
||||
#### 系统兼容性修复 / System Compatibility Fixes
|
||||
- **Windows脚本**: 修复Windows daemon.bat路径配置问题 ([e5a41765b](https://github.com/tencentmusic/supersonic/commit/e5a41765b) by 柯慕灵)
|
||||
- **字符编码**: 将utf8编码修改为utf8mb4,解决字符问题 ([2e81b190a](https://github.com/tencentmusic/supersonic/commit/2e81b190a) by Kun Gu)
|
||||
- **记忆缓存**: 修复记忆管理中因缓存无法存储的问题 ([81cd60d2d](https://github.com/tencentmusic/supersonic/commit/81cd60d2d) by guilinlewis)
|
||||
- **Mac兼容**: 降级djl库以支持Mac Intel机器 ([bf3213e8f](https://github.com/tencentmusic/supersonic/commit/bf3213e8f) by jerryjzhang)
|
||||
|
||||
### 数据管理优化 / Data Management Improvements
|
||||
|
||||
#### 维度指标管理 / Dimension & Metric Management
|
||||
- **维度检索**: 修复维度和指标检索及百分比显示问题 ([d8fe2ed2b](https://github.com/tencentmusic/supersonic/commit/d8fe2ed2b) by 木鱼和尚)
|
||||
- **查询导出**: 基于queryColumns导出数据 ([11d1264d3](https://github.com/tencentmusic/supersonic/commit/11d1264d3) by FredTsang)
|
||||
- **表格排序**: 移除表格defaultSortOrder ([32675387d](https://github.com/tencentmusic/supersonic/commit/32675387d) by FredTsang)
|
||||
- **维度搜索**: 修复维度搜索带key查询范围超出问题 ([269f146c1](https://github.com/tencentmusic/supersonic/commit/269f146c1) by wangyong)
|
||||
|
||||
### 测试和质量保证 / Testing & Quality Assurance
|
||||
|
||||
#### 单元测试 / Unit Testing
|
||||
- **测试修复**: 修复单元测试用例 ([91e4b51ef](https://github.com/tencentmusic/supersonic/commit/91e4b51ef) by jerryjzhang)
|
||||
- **模型测试**: 修复ModelCreateForm.tsx错误 ([d2aa73b85](https://github.com/tencentmusic/supersonic/commit/d2aa73b85) by Antgeek)
|
||||
|
||||
### 重要变更说明 / Breaking Changes
|
||||
|
||||
#### 升级注意事项 / Upgrade Notes
|
||||
1. **SpringBoot 3升级**: 可能需要更新依赖配置和代码适配
|
||||
2. **许可证变更**: 从MIT变更为Apache 2.0,请注意法律合规
|
||||
3. **API接口调整**: 部分API接口为支持新功能进行了调整
|
||||
4. **数据库兼容**: 新增多种数据库支持,配置方式有所变化
|
||||
|
||||
### 完整提交统计 / Commit Statistics
|
||||
- **总提交数**: 419个提交
|
||||
- **主要贡献者**:
|
||||
- jerryjzhang: 158次提交
|
||||
- supersonicbi: 22次提交
|
||||
- zyclove: 20次提交
|
||||
- beat4ocean: 15次提交
|
||||
- guilinlewis: 11次提交
|
||||
- wangyong: 11次提交
|
||||
- 其他贡献者: 182次提交
|
||||
- **涉及模块**: headless, chat, auth, common, webapp, launcher, docker
|
||||
- **时间跨度**: 2024年11月1日 - 2025年8月5日
|
||||
|
||||
### 致谢 / Acknowledgments
|
||||
|
||||
感谢所有为SuperSonic 1.0.0版本贡献代码、文档、测试和建议的开发者们!🎉
|
||||
|
||||
#### 核心贡献者 / Core Contributors
|
||||
- **jerryjzhang** - 项目维护者,核心架构设计与实现
|
||||
- **supersonicbi** - 核心功能开发,多数据库支持
|
||||
- **beat4ocean** - 架构升级,依赖管理,安全优化
|
||||
- **zyclove** - 数据库适配,构建优化
|
||||
- **guilinlewis** - 鉴权系统,召回优化
|
||||
- **wangyong** - Bug修复,NPE异常处理
|
||||
|
||||
#### 活跃贡献者 / Active Contributors
|
||||
- **WDEP** - 前端优化,图表功能
|
||||
- **FredTsang** - Chat SDK优化,数据导出
|
||||
- **feelshana** - 记忆管理,向量召回
|
||||
- **QJ_wonder** - 插件功能,多轮对话
|
||||
- **Willy-J** - 模型关联,数据库兼容
|
||||
- **iridescentpeo** - 查询优化,模型管理
|
||||
- **tristanliu** - 前端路由,权限界面
|
||||
- **mislayming** - SpringBoot 3升级
|
||||
- **Antgeek** - 基准测试,模型修复
|
||||
- **柯慕灵** - LLM纠错器,Windows脚本
|
||||
- **superhero** - 项目管理,代码审查
|
||||
|
||||
#### 其他重要贡献者 / Other Important Contributors
|
||||
- **木鱼和尚** - 维度指标检索优化
|
||||
- **pisces** - 图表导出功能
|
||||
- **lexluo09** - 并行处理,GitHub Actions
|
||||
- **andybj0228** - SQL查询优化
|
||||
- **wanglongqiang** - SQL变量支持
|
||||
- **Hyman_bz** - StarRocks支持
|
||||
- **wwsheng009** - SAP HANA适配
|
||||
- **poncheen** - TypeScript错误修复
|
||||
- **kino** - Docker镜像优化
|
||||
- **coosir** - 前端百分比显示
|
||||
- **Kun Gu** - 字符编码优化
|
||||
- **chixiaopao** - NPE异常修复
|
||||
- **naimehao** - 核心功能修复
|
||||
- **yudong** - 记忆评估优化
|
||||
- **mroldx** - 数据库脚本更新
|
||||
- **ChPi** - 解析器性能优化
|
||||
- **Hwting** - Docker配置优化
|
||||
|
||||
#### 特别感谢 / Special Thanks
|
||||
感谢所有提交Issue、参与讨论、提供反馈的社区用户,你们的每一个建议都让SuperSonic变得更好!
|
||||
|
||||
#### 社区支持 / Community Support
|
||||
SuperSonic是一个开源项目,我们欢迎更多开发者加入:
|
||||
- 🔗 **GitHub**: https://github.com/tencentmusic/supersonic
|
||||
- 📖 **文档**: 详见项目README和Wiki
|
||||
- 🐛 **Issue报告**: 欢迎提交Bug和功能请求
|
||||
- 🚀 **贡献代码**: 欢迎提交Pull Request
|
||||
- 💬 **社区讨论**: 加入我们的技术交流群
|
||||
|
||||
#### 未来展望 / Future Vision
|
||||
SuperSonic 1.0.0是一个重要的里程碑,但这只是开始。我们将继续:
|
||||
- 🌟 **持续优化性能和稳定性**
|
||||
- 🔧 **扩展更多数据库和AI模型支持**
|
||||
- 🎨 **改善用户体验和界面设计**
|
||||
- 📚 **完善文档和最佳实践**
|
||||
- 🤝 **建设更活跃的开源社区**
|
||||
|
||||
**让我们一起把SuperSonic做得更好!** ✨
|
||||
|
||||
---
|
||||
|
||||
*如果您在使用过程中遇到问题或有改进建议,欢迎随时与我们交流。每一份贡献都让SuperSonic更加强大!*
|
||||
|
||||
|
||||
## SuperSonic [0.9.8] - 2024-11-01
|
||||
- Add LLM management module to reuse connection across agents.
|
||||
|
||||
113
CLAUDE.md
113
CLAUDE.md
@@ -1,113 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Build Commands
|
||||
|
||||
### Backend (Java/Maven)
|
||||
|
||||
```bash
|
||||
# Clean build (skip tests)
|
||||
mvn clean package -DskipTests -Dspotless.skip=true
|
||||
|
||||
# Run all tests
|
||||
mvn test
|
||||
|
||||
# Run single test class
|
||||
mvn test -Dtest=ClassName
|
||||
|
||||
# Full CI build
|
||||
mvn -B package --file pom.xml
|
||||
```
|
||||
|
||||
**Requirements:** Java 21, Maven
|
||||
|
||||
### Frontend (pnpm/React)
|
||||
|
||||
```bash
|
||||
cd webapp
|
||||
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Start dev server (port 9000)
|
||||
pnpm dev
|
||||
|
||||
# Production build
|
||||
pnpm build
|
||||
|
||||
# Run tests
|
||||
pnpm test
|
||||
```
|
||||
|
||||
**Requirements:** Node.js >=16, pnpm 9.12.3+
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Build full release
|
||||
./assembly/bin/supersonic-build.sh standalone
|
||||
|
||||
# Start service
|
||||
./assembly/bin/supersonic-daemon.sh start
|
||||
|
||||
# Stop service
|
||||
./assembly/bin/supersonic-daemon.sh stop
|
||||
```
|
||||
|
||||
Visit http://localhost:9080 after startup.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
SuperSonic unifies **Chat BI** (LLM-powered) and **Headless BI** (semantic layer) paradigms.
|
||||
|
||||
### Core Modules
|
||||
|
||||
```
|
||||
supersonic/
|
||||
├── auth/ # Authentication & authorization (SPI-based)
|
||||
├── chat/ # Chat BI module - LLM-powered Q&A interface
|
||||
├── common/ # Shared utilities
|
||||
├── headless/ # Headless BI - semantic layer with open API
|
||||
├── launchers/ # Application entry points
|
||||
│ ├── standalone/ # Combined Chat + Headless (default)
|
||||
│ ├── chat/ # Chat-only service
|
||||
│ └── headless/ # Headless-only service
|
||||
└── webapp/ # Frontend React app (UmiJS 4 + Ant Design)
|
||||
```
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Knowledge Base**: Extracts schema from semantic models, builds dictionary/index for schema mapping
|
||||
2. **Schema Mapper**: Identifies metrics/dimensions/entities/values in user queries
|
||||
3. **Semantic Parser**: Generates S2SQL (semantic SQL) using rule-based and LLM-based parsers
|
||||
4. **Semantic Corrector**: Validates and corrects semantic queries
|
||||
5. **Semantic Translator**: Converts S2SQL to executable SQL
|
||||
|
||||
### Key Entry Points
|
||||
|
||||
- `StandaloneLauncher.java` - Combined service with `scanBasePackages: ["com.tencent.supersonic", "dev.langchain4j"]`
|
||||
- `ChatLauncher.java` - Chat BI only
|
||||
- `HeadlessLauncher.java` - Headless BI only
|
||||
|
||||
## Key Technologies
|
||||
|
||||
**Backend:** Spring Boot 3.3.9, MyBatis-Plus 3.5.10.1, LangChain4j 0.36.2, JSqlParser 4.9, Calcite 1.38.0
|
||||
|
||||
**Frontend:** React 18, UmiJS 4, Ant Design 5.17.4, ECharts 5.0.2, AntV G6/X6
|
||||
|
||||
**Databases:** MySQL, PostgreSQL (with pgvector), H2, ClickHouse, StarRocks, Presto, Trino, DuckDB
|
||||
|
||||
## Testing
|
||||
|
||||
**Java tests:** JUnit 5, Mockito. Located in `src/test/java/` of each module.
|
||||
|
||||
**Frontend tests:** Jest with Puppeteer environment in `webapp/packages/supersonic-fe/`
|
||||
|
||||
**Evaluation scripts:** Python scripts in `evaluation/` directory for Text2SQL accuracy testing.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [README.md](README.md) - English documentation
|
||||
- [README_CN.md](README_CN.md) - Chinese documentation
|
||||
- [Evaluation Guide](evaluation/README.md) - Text2SQL evaluation process
|
||||
@@ -43,26 +43,10 @@ if "%service%"=="webapp" (
|
||||
call mvn -f %projectDir% clean package -DskipTests -Dspotless.skip=true
|
||||
IF ERRORLEVEL 1 (
|
||||
ECHO Failed to build backend Java modules.
|
||||
ECHO Please check Maven and Java versions are compatible.
|
||||
ECHO Current Java: %JAVA_HOME%
|
||||
ECHO Current Maven: %MAVEN_HOME%
|
||||
EXIT /B 1
|
||||
)
|
||||
|
||||
REM extract and copy files to deployment directory
|
||||
cd %projectDir%\launchers\%model_name%\target
|
||||
if exist "launchers-%model_name%-%MVN_VERSION%-bin.tar.gz" (
|
||||
echo "Extracting launchers-%model_name%-%MVN_VERSION%-bin.tar.gz..."
|
||||
tar -xf "launchers-%model_name%-%MVN_VERSION%-bin.tar.gz"
|
||||
if exist "launchers-%model_name%-%MVN_VERSION%" (
|
||||
echo "Copying files to deployment directory..."
|
||||
xcopy /E /Y "launchers-%model_name%-%MVN_VERSION%\*" "%buildDir%\supersonic-%model_name%-%MVN_VERSION%\"
|
||||
)
|
||||
)
|
||||
|
||||
copy /y %projectDir%\launchers\%model_name%\target\*.tar.gz %buildDir%\
|
||||
echo "finished building supersonic-%model_name% service"
|
||||
cd %baseDir%
|
||||
goto :EOF
|
||||
|
||||
|
||||
@@ -88,55 +72,22 @@ if "%service%"=="webapp" (
|
||||
cd %buildDir%
|
||||
if exist %release_dir% rmdir /s /q %release_dir%
|
||||
if exist %release_dir%.zip del %release_dir%.zip
|
||||
|
||||
rem check if release directory already exists from buildJavaService
|
||||
if exist %release_dir% (
|
||||
echo "Release directory already prepared by buildJavaService"
|
||||
) else (
|
||||
mkdir %release_dir%
|
||||
|
||||
rem package java service
|
||||
tar xvf %service_name%-bin.tar.gz 2>nul
|
||||
if errorlevel 1 (
|
||||
echo "Warning: tar command failed, trying PowerShell extraction..."
|
||||
powershell -Command "Expand-Archive -Path '%service_name%-bin.tar.gz' -DestinationPath '.' -Force"
|
||||
)
|
||||
for /d %%D in ("%service_name%\*") do (
|
||||
move "%%D" "%release_dir%"
|
||||
)
|
||||
rmdir /s /q %service_name% 2>nul
|
||||
)
|
||||
|
||||
mkdir %release_dir%
|
||||
rem package webapp
|
||||
if exist supersonic-webapp.tar.gz (
|
||||
tar xvf supersonic-webapp.tar.gz 2>nul
|
||||
if errorlevel 1 (
|
||||
echo "Warning: tar command failed, trying PowerShell extraction..."
|
||||
powershell -Command "Expand-Archive -Path 'supersonic-webapp.tar.gz' -DestinationPath '.' -Force"
|
||||
)
|
||||
move /y supersonic-webapp webapp
|
||||
echo {"env": ""} > webapp\supersonic.config.json
|
||||
move /y webapp %release_dir%
|
||||
del supersonic-webapp.tar.gz 2>nul
|
||||
tar xvf supersonic-webapp.tar.gz
|
||||
move /y supersonic-webapp webapp
|
||||
echo {"env": ""} > webapp\supersonic.config.json
|
||||
move /y webapp %release_dir%
|
||||
rem package java service
|
||||
tar xvf %service_name%-bin.tar.gz
|
||||
for /d %%D in ("%service_name%\*") do (
|
||||
move "%%D" "%release_dir%"
|
||||
)
|
||||
|
||||
rem verify deployment structure
|
||||
if exist "%release_dir%\lib\launchers-%model_name%-%MVN_VERSION%.jar" (
|
||||
echo "Deployment structure verified successfully"
|
||||
) else (
|
||||
echo "Warning: Main jar file not found in deployment structure"
|
||||
echo "Expected: %release_dir%\lib\launchers-%model_name%-%MVN_VERSION%.jar"
|
||||
)
|
||||
|
||||
rem generate zip file
|
||||
powershell -Command "Compress-Archive -Path '%release_dir%' -DestinationPath '%release_dir%.zip' -Force"
|
||||
if errorlevel 1 (
|
||||
echo "Warning: PowerShell compression failed, release directory still available: %release_dir%"
|
||||
) else (
|
||||
echo "Successfully created release package: %release_dir%.zip"
|
||||
)
|
||||
|
||||
del %service_name%-bin.tar.gz 2>nul
|
||||
powershell Compress-Archive -Path %release_dir% -DestinationPath %release_dir%.zip
|
||||
del %service_name%-bin.tar.gz
|
||||
del supersonic-webapp.tar.gz
|
||||
rmdir /s /q %service_name%
|
||||
echo "finished packaging supersonic release"
|
||||
goto :EOF
|
||||
|
||||
|
||||
@@ -20,9 +20,7 @@ if "%profile%"=="" (
|
||||
|
||||
set "model_name=%service%"
|
||||
|
||||
REM fix path configuration - point to the correct release package directory
|
||||
set "releaseDir=%buildDir%\supersonic-%service%-1.0.0-SNAPSHOT"
|
||||
cd %releaseDir%
|
||||
cd %baseDir%
|
||||
|
||||
if "%command%"=="restart" (
|
||||
call :stop
|
||||
@@ -52,58 +50,20 @@ if "%command%"=="restart" (
|
||||
|
||||
:runJavaService
|
||||
echo 'java service starting, see logs in logs/'
|
||||
echo 'Using release directory: %releaseDir%'
|
||||
|
||||
REM use release package directory as base path
|
||||
set "libDir=%releaseDir%\lib"
|
||||
set "confDir=%releaseDir%\conf"
|
||||
set "webDir=%releaseDir%\webapp"
|
||||
set "logDir=%releaseDir%\logs"
|
||||
|
||||
REM fix variable name matching problem
|
||||
set "CLASSPATH=%releaseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||
set "MAIN_CLASS=%main_class%"
|
||||
|
||||
REM add port configuration
|
||||
set "property=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Dserver.port=9080"
|
||||
set "java_command=%property% -Xms1024m -Xmx2048m -cp "%CLASSPATH%" %MAIN_CLASS%"
|
||||
|
||||
set "libDir=%baseDir%\lib"
|
||||
set "confDir=%baseDir%\conf"
|
||||
set "webDir=%baseDir%\webapp"
|
||||
set "logDir=%baseDir%\logs"
|
||||
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||
set "property=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile%"
|
||||
set "java-command=%property% -Xms1024m -Xmx1024m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||
if not exist %logDir% mkdir %logDir%
|
||||
|
||||
REM check if the main jar file exists
|
||||
if not exist "%libDir%\launchers-standalone-1.0.0-SNAPSHOT.jar" (
|
||||
echo "Error: Main jar file not found in %libDir%"
|
||||
echo "Please make sure the application has been built and packaged correctly."
|
||||
goto :EOF
|
||||
)
|
||||
|
||||
echo 'Main Class: %MAIN_CLASS%'
|
||||
echo 'Profile: %profile%'
|
||||
echo 'Starting Java service...'
|
||||
|
||||
REM start service and save logs
|
||||
start /B java %java_command% > "%logDir%\supersonic.log" 2>&1
|
||||
timeout /t 15 >nul
|
||||
|
||||
REM check service status
|
||||
netstat -an | findstr ":9080" >nul
|
||||
if errorlevel 1 (
|
||||
echo "Warning: Port 9080 is not listening"
|
||||
echo "Please check the log file: %logDir%\supersonic.log"
|
||||
if exist "%logDir%\supersonic.log" (
|
||||
echo "Recent log entries:"
|
||||
powershell -Command "Get-Content '%logDir%\supersonic.log' | Select-Object -Last 10"
|
||||
)
|
||||
) else (
|
||||
echo "Service started successfully on port 9080"
|
||||
echo "You can access the application at: http://localhost:9080"
|
||||
)
|
||||
|
||||
start /B java %java-command% >nul 2>&1
|
||||
timeout /t 10 >nul
|
||||
echo 'java service started'
|
||||
goto :EOF
|
||||
|
||||
:stopJavaService
|
||||
echo 'Stopping Java service...'
|
||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
||||
taskkill /PID %%i /F
|
||||
echo "java service (PID = %%i) is killed."
|
||||
|
||||
@@ -60,8 +60,7 @@ function runJavaService {
|
||||
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
||||
fi
|
||||
export PATH=$JAVA_HOME/bin:$PATH
|
||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08
|
||||
-Dapp_name=${local_app_name} -Xms1024m -Xmx2048m -XX:+UseZGC -XX:+ZGenerational $main_class"
|
||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dapp_name=${local_app_name} -Xms1024m -Xmx1024m $main_class"
|
||||
|
||||
mkdir -p $javaRunDir/logs
|
||||
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||
|
||||
@@ -34,8 +34,8 @@
|
||||
</dependencies>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>21</maven.compiler.source>
|
||||
<maven.compiler.target>21</maven.compiler.target>
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
</project>
|
||||
@@ -21,8 +21,6 @@ public interface UserAdaptor {
|
||||
|
||||
void register(UserReq userReq);
|
||||
|
||||
void deleteUser(long userId);
|
||||
|
||||
String login(UserReq userReq, HttpServletRequest request);
|
||||
|
||||
String login(UserReq userReq, String appKey);
|
||||
|
||||
@@ -24,7 +24,7 @@ public class UserWithPassword extends User {
|
||||
|
||||
public UserWithPassword(Long id, String name, String displayName, String email, String password,
|
||||
Integer isAdmin) {
|
||||
super(id, name, displayName, email, isAdmin, null);
|
||||
super(id, name, displayName, email, isAdmin);
|
||||
this.password = password;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,4 @@ public class UserReq {
|
||||
|
||||
@NotBlank(message = "password can not be null")
|
||||
private String newPassword;
|
||||
|
||||
private String role;
|
||||
}
|
||||
|
||||
@@ -23,8 +23,6 @@ public interface UserService {
|
||||
|
||||
void register(UserReq userCmd);
|
||||
|
||||
void deleteUser(long userId);
|
||||
|
||||
String login(UserReq userCmd, HttpServletRequest request);
|
||||
|
||||
String login(UserReq userCmd, String appKey);
|
||||
|
||||
@@ -18,8 +18,6 @@ import jakarta.servlet.http.HttpServletRequest;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
@@ -92,12 +90,6 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
||||
userRepository.addUser(userDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteUser(long userId) {
|
||||
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||
userRepository.deleteUser(userId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String login(UserReq userReq, HttpServletRequest request) {
|
||||
TokenService tokenService = ContextUtils.getBean(TokenService.class);
|
||||
@@ -110,9 +102,7 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
||||
TokenService tokenService = ContextUtils.getBean(TokenService.class);
|
||||
try {
|
||||
UserWithPassword user = getUserWithPassword(userReq);
|
||||
String token = tokenService.generateToken(UserWithPassword.convert(user), appKey);
|
||||
updateLastLogin(userReq.getName());
|
||||
return token;
|
||||
return tokenService.generateToken(UserWithPassword.convert(user), appKey);
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
throw new RuntimeException("password encrypt error, please try again");
|
||||
@@ -223,9 +213,8 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
||||
new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
||||
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
||||
|
||||
// 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况
|
||||
String token = tokenService.generateToken(UserWithPassword.convert(userWithPassword),
|
||||
"SysDbToken:" + name, (new Date().getTime() + expireTime));
|
||||
String token =
|
||||
tokenService.generateToken(UserWithPassword.convert(userWithPassword), expireTime);
|
||||
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
|
||||
return convertUserToken(userTokenDO);
|
||||
}
|
||||
@@ -278,11 +267,4 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
||||
userToken.setExpireDate(userTokenDO.getExpireDateTime());
|
||||
return userToken;
|
||||
}
|
||||
|
||||
private void updateLastLogin(String userName) {
|
||||
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||
UserDO userDO = userRepository.getUser(userName);
|
||||
userDO.setLastLogin(new Timestamp(System.currentTimeMillis()));
|
||||
userRepository.updateUser(userDO);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,7 @@ package com.tencent.supersonic.auth.authentication.persistence.dataobject;
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
|
||||
@Data
|
||||
@TableName("s2_user")
|
||||
public class UserDO {
|
||||
|
||||
@@ -31,25 +27,71 @@ public class UserDO {
|
||||
/** */
|
||||
private Integer isAdmin;
|
||||
|
||||
private Timestamp lastLogin;
|
||||
/** @return id */
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
/** @param id */
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
/** @return name */
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/** @param name */
|
||||
public void setName(String name) {
|
||||
this.name = name == null ? null : name.trim();
|
||||
}
|
||||
|
||||
/** @return password */
|
||||
public String getPassword() {
|
||||
return password;
|
||||
}
|
||||
|
||||
/** @param password */
|
||||
public void setPassword(String password) {
|
||||
this.password = password == null ? null : password.trim();
|
||||
}
|
||||
|
||||
public String getSalt() {
|
||||
return salt;
|
||||
}
|
||||
|
||||
public void setSalt(String salt) {
|
||||
this.salt = salt == null ? null : salt.trim();
|
||||
}
|
||||
|
||||
/** @return display_name */
|
||||
public String getDisplayName() {
|
||||
return displayName;
|
||||
}
|
||||
|
||||
/** @param displayName */
|
||||
public void setDisplayName(String displayName) {
|
||||
this.displayName = displayName == null ? null : displayName.trim();
|
||||
}
|
||||
|
||||
/** @return email */
|
||||
public String getEmail() {
|
||||
return email;
|
||||
}
|
||||
|
||||
/** @param email */
|
||||
public void setEmail(String email) {
|
||||
this.email = email == null ? null : email.trim();
|
||||
}
|
||||
|
||||
/** @return is_admin */
|
||||
public Integer getIsAdmin() {
|
||||
return isAdmin;
|
||||
}
|
||||
|
||||
/** @param isAdmin */
|
||||
public void setIsAdmin(Integer isAdmin) {
|
||||
this.isAdmin = isAdmin;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,11 +21,7 @@ public interface UserRepository {
|
||||
|
||||
UserTokenDO getUserToken(Long tokenId);
|
||||
|
||||
UserTokenDO getUserTokenByName(String tokenName);
|
||||
|
||||
void deleteUserTokenByName(String userName);
|
||||
|
||||
void deleteUserToken(Long tokenId);
|
||||
|
||||
void deleteUser(long userId);
|
||||
}
|
||||
|
||||
@@ -65,13 +65,6 @@ public class UserRepositoryImpl implements UserRepository {
|
||||
return userTokenDOMapper.selectById(tokenId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UserTokenDO getUserTokenByName(String tokenName) {
|
||||
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(UserTokenDO::getName, tokenName);
|
||||
return userTokenDOMapper.selectOne(queryWrapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteUserTokenByName(String userName) {
|
||||
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||
@@ -83,9 +76,4 @@ public class UserRepositoryImpl implements UserRepository {
|
||||
public void deleteUserToken(Long tokenId) {
|
||||
userTokenDOMapper.deleteById(tokenId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteUser(long userId) {
|
||||
userDOMapper.deleteById(userId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,13 @@ import com.tencent.supersonic.common.pojo.User;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
@@ -61,16 +67,6 @@ public class UserController {
|
||||
userService.register(userCmd);
|
||||
}
|
||||
|
||||
@DeleteMapping("/delete/{userId}")
|
||||
public void delete(@PathVariable("userId") long userId, HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) throws IllegalAccessException {
|
||||
User user = userService.getCurrentUser(httpServletRequest, httpServletResponse);
|
||||
if (user.getIsAdmin() != 1) {
|
||||
throw new IllegalAccessException("only admin can delete user");
|
||||
}
|
||||
userService.deleteUser(userId);
|
||||
}
|
||||
|
||||
@PostMapping("/login")
|
||||
public String login(@RequestBody UserReq userCmd, HttpServletRequest request) {
|
||||
return userService.login(userCmd, request);
|
||||
|
||||
@@ -70,11 +70,6 @@ public class UserServiceImpl implements UserService {
|
||||
ComponentFactory.getUserAdaptor().register(userReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteUser(long userId) {
|
||||
ComponentFactory.getUserAdaptor().deleteUser(userId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String login(UserReq userReq, HttpServletRequest request) {
|
||||
return ComponentFactory.getUserAdaptor().login(userReq, request);
|
||||
|
||||
@@ -6,10 +6,7 @@ import javax.crypto.spec.SecretKeySpec;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserTokenDO;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
|
||||
import com.tencent.supersonic.common.pojo.exception.AccessException;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import io.jsonwebtoken.Claims;
|
||||
import io.jsonwebtoken.Jwts;
|
||||
import io.jsonwebtoken.SignatureAlgorithm;
|
||||
@@ -74,7 +71,6 @@ public class TokenService {
|
||||
return generateToken(UserWithPassword.convert(appUser), request);
|
||||
}
|
||||
|
||||
|
||||
public Optional<Claims> getClaims(HttpServletRequest request) {
|
||||
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
||||
String appKey = getAppKey(request);
|
||||
@@ -94,14 +90,6 @@ public class TokenService {
|
||||
|
||||
public Optional<Claims> getClaims(String token, String appKey) {
|
||||
try {
|
||||
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
|
||||
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||
UserTokenDO dbToken =
|
||||
userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
|
||||
if (dbToken == null || !dbToken.getToken().equals(token.replace("Bearer ", ""))) {
|
||||
throw new AccessException("Token does not exist :" + appKey);
|
||||
}
|
||||
}
|
||||
String tokenSecret = getTokenSecret(appKey);
|
||||
Claims claims =
|
||||
Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
|
||||
@@ -134,16 +122,6 @@ public class TokenService {
|
||||
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
|
||||
String secret = appKeyToSecretMap.get(appKey);
|
||||
if (StringUtils.isBlank(secret)) {
|
||||
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
|
||||
String realAppKey = appKey.substring("SysDbToken:".length());
|
||||
String tmp =
|
||||
"WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
|
||||
if (tmp.length() <= realAppKey.length()) {
|
||||
return realAppKey;
|
||||
} else {
|
||||
return realAppKey + tmp.substring(realAppKey.length());
|
||||
}
|
||||
}
|
||||
throw new AccessException("get secret from appKey failed :" + appKey);
|
||||
}
|
||||
return secret;
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
|
||||
<result column="email" jdbcType="VARCHAR" property="email" />
|
||||
<result column="is_admin" jdbcType="INTEGER" property="isAdmin" />
|
||||
<result column="last_login" jdbcType="TIMESTAMP" property="lastLogin" />
|
||||
</resultMap>
|
||||
<sql id="Example_Where_Clause">
|
||||
<where>
|
||||
@@ -41,7 +40,7 @@
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Base_Column_List">
|
||||
id, name, password, salt, display_name, email, is_admin, last_login
|
||||
id, name, password, salt, display_name, email, is_admin
|
||||
</sql>
|
||||
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
|
||||
select
|
||||
@@ -137,9 +136,6 @@
|
||||
<if test="isAdmin != null">
|
||||
is_admin = #{isAdmin,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="lastLogin != null">
|
||||
last_login = #{lastLogin,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
</set>
|
||||
where id = #{id,jdbcType=BIGINT}
|
||||
</update>
|
||||
|
||||
@@ -15,68 +15,6 @@ import requests
|
||||
import time
|
||||
import jwt
|
||||
import traceback
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class DataFrameAppender:
|
||||
def __init__(self,file_name = "output"):
|
||||
# 定义表头
|
||||
columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
|
||||
# 创建只有表头的 DataFrame
|
||||
self.df = pd.DataFrame(columns=columns)
|
||||
self.file_name = file_name
|
||||
|
||||
def append_data(self, new_data):
|
||||
# 假设 new_data 是一维数组,将其转换为字典
|
||||
columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
|
||||
new_dict = dict(zip(columns, new_data))
|
||||
# 使用 loc 方法追加数据
|
||||
self.df.loc[len(self.df)] = new_dict
|
||||
def print_analysis_result(self):
|
||||
# 测试样例总数
|
||||
total_samples = len(self.df)
|
||||
|
||||
# 解析成功数量
|
||||
parse_success_count = (self.df['解析状态'] == '解析成功').sum()
|
||||
|
||||
# 执行成功数量
|
||||
execute_success_count = (self.df['执行状态'] == '执行成功').sum()
|
||||
|
||||
# 解析平均耗时,保留两位小数
|
||||
avg_parse_time = round(self.df['解析耗时'].mean(), 2)
|
||||
|
||||
# 执行平均耗时,保留两位小数
|
||||
avg_execute_time = round(self.df['执行耗时'].mean(), 2)
|
||||
|
||||
# 总平均耗时,保留两位小数
|
||||
avg_total_time = round(self.df['总耗时'].mean(), 2)
|
||||
|
||||
# 最长耗时,保留两位小数
|
||||
max_time = round(self.df['总耗时'].max(), 2)
|
||||
|
||||
# 最短耗时,保留两位小数
|
||||
min_time = round(self.df['总耗时'].min(), 2)
|
||||
|
||||
print(f"测试样例总数 : {total_samples}")
|
||||
print(f"解析成功数量 : {parse_success_count}")
|
||||
print(f"执行成功数量 : {execute_success_count}")
|
||||
print(f"解析平均耗时 : {avg_parse_time} 秒")
|
||||
print(f"执行平均耗时 : {avg_execute_time} 秒")
|
||||
print(f"总平均耗时 : {avg_total_time} 秒")
|
||||
print(f"最长耗时 : {max_time} 秒")
|
||||
print(f"最短耗时 : {min_time} 秒")
|
||||
|
||||
def write_to_csv(self):
|
||||
# 检查 data 文件夹是否存在,如果不存在则创建
|
||||
if not os.path.exists('res'):
|
||||
os.makedirs('res')
|
||||
# 获取当前时间戳
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
# 生成带时间戳的文件名
|
||||
file_path = os.path.join('res', f'{self.file_name}_{timestamp}.csv')
|
||||
self.df.to_csv(file_path, index=False)
|
||||
print(f"测试结果已保存到 {file_path}")
|
||||
|
||||
class BatchTest:
|
||||
def __init__(self, url, agentId, chatId, userName):
|
||||
@@ -132,35 +70,18 @@ class BatchTest:
|
||||
def benchmark(url:str, agentId:str, chatId:str, filePath:str, userName:str):
|
||||
batch_test = BatchTest(url, agentId, chatId, userName)
|
||||
df = batch_test.read_question_from_csv(filePath)
|
||||
appender = DataFrameAppender(os.path.basename(filePath))
|
||||
for index, row in df.iterrows():
|
||||
question = row['question']
|
||||
print('start to ask question:', question)
|
||||
# 捕获异常,防止程序中断
|
||||
try:
|
||||
parse_resp = batch_test.parse(question)
|
||||
parse_status = '解析失败'
|
||||
if parse_resp.get('data').get('errorMsg') is None:
|
||||
parse_status = '解析成功'
|
||||
parse_cost = parse_resp.get('data').get('parseTimeCost').get('parseTime')
|
||||
execute_resp = batch_test.execute(agentId, question, parse_resp['data']['queryId'])
|
||||
execute_status = '执行失败'
|
||||
execute_cost = 0
|
||||
if parse_status == '解析成功' and execute_resp.get('data').get('errorMsg') is None:
|
||||
execute_status = '执行成功'
|
||||
execute_cost = execute_resp.get('data').get('queryTimeCost')
|
||||
res = [question.replace(',', '#'),parse_status,parse_cost/1000,execute_status,execute_cost/1000,(parse_cost+execute_cost)/1000]
|
||||
appender.append_data(res)
|
||||
|
||||
batch_test.execute(agentId, question, parse_resp['data']['queryId'])
|
||||
except Exception as e:
|
||||
print('error:', e)
|
||||
traceback.print_exc()
|
||||
continue
|
||||
time.sleep(1)
|
||||
# 打印分析结果
|
||||
appender.print_analysis_result()
|
||||
# 分析明细输出
|
||||
appender.write_to_csv()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
@@ -18,5 +18,4 @@ public class ChatExecuteReq {
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private boolean saveAnswer;
|
||||
private boolean streamingResult;
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ChatMemoryDeleteReq {
|
||||
|
||||
private List<Long> ids;
|
||||
|
||||
private Integer agentId;
|
||||
}
|
||||
@@ -13,7 +13,6 @@ public class QueryResp {
|
||||
private Long questionId;
|
||||
private Date createTime;
|
||||
private Long chatId;
|
||||
private Integer agentId;
|
||||
private Integer score;
|
||||
private String feedback;
|
||||
private String queryText;
|
||||
|
||||
@@ -75,12 +75,8 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
return null;
|
||||
}
|
||||
|
||||
// 使用querySQL,它已经包含了所有修正(包括物理SQL修正)
|
||||
String finalSql = StringUtils.isNotBlank(parseInfo.getSqlInfo().getQuerySQL())
|
||||
? parseInfo.getSqlInfo().getQuerySQL()
|
||||
: parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
QuerySqlReq sqlReq = QuerySqlReq.builder().sql(finalSql).build();
|
||||
QuerySqlReq sqlReq =
|
||||
QuerySqlReq.builder().sql(parseInfo.getSqlInfo().getCorrectedS2SQL()).build();
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
|
||||
@@ -94,7 +90,7 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
queryResult.setQuerySql(finalSql);
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQueryResults(queryResp.getResultList());
|
||||
queryResult.setQueryColumns(queryResp.getColumns());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
@@ -32,7 +32,6 @@ import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -172,6 +171,10 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
return;
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
||||
|
||||
List<QueryResp> historyQueries =
|
||||
getHistoryQueries(parseContext.getRequest().getChatId(), 1);
|
||||
if (historyQueries.isEmpty()) {
|
||||
@@ -179,18 +182,12 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
QueryResp lastQuery = historyQueries.get(0);
|
||||
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(histSQL)) // 优化性能,如果问答不是chat bi 则无需重写,因为数据都不全
|
||||
return;
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq); // 优化性能 ,只有满足条件才mapping
|
||||
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
|
||||
String curtMapStr =
|
||||
generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("current_question", currentMapResult.getQueryText());
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@TableName("s2_chat")
|
||||
public class ChatDO {
|
||||
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long chatId;
|
||||
private long chatId;
|
||||
private Integer agentId;
|
||||
private String chatName;
|
||||
private String createTime;
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
public class DictConfDO {
|
||||
|
||||
private Long id;
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private String dimValueInfos;
|
||||
|
||||
private String createdBy;
|
||||
private String updatedBy;
|
||||
private Date createdAt;
|
||||
private Date updatedAt;
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import org.apache.commons.codec.digest.DigestUtils;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class DictTaskDO {
|
||||
|
||||
private Long id;
|
||||
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
||||
private String command;
|
||||
|
||||
private String commandMd5;
|
||||
|
||||
private String dimIds;
|
||||
|
||||
private Integer status;
|
||||
|
||||
private String createdBy;
|
||||
|
||||
private Date createdAt;
|
||||
|
||||
private Double progress;
|
||||
|
||||
private Long elapsedMs;
|
||||
|
||||
public String getCommandMd5() {
|
||||
return DigestUtils.md5Hex(command);
|
||||
}
|
||||
}
|
||||
@@ -35,7 +35,9 @@ public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
|
||||
if (CollectionUtils.isEmpty(ids)) {
|
||||
return;
|
||||
}
|
||||
chatMemoryMapper.deleteByIds(ids);
|
||||
for (Long id : ids) {
|
||||
chatMemoryMapper.deleteById(id);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -148,7 +148,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
chatQueryDO.setUserName(chatParseReq.getUser().getName());
|
||||
chatQueryDO.setQueryText(chatParseReq.getQueryText());
|
||||
chatQueryDO.setAgentId(chatParseReq.getAgentId());
|
||||
chatQueryDO.setQueryResult("{}");
|
||||
chatQueryDO.setQueryResult("");
|
||||
chatQueryDO.setQueryState(1);
|
||||
try {
|
||||
chatQueryDOMapper.insert(chatQueryDO);
|
||||
|
||||
@@ -88,10 +88,10 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
try {
|
||||
responseEntity =
|
||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, String.class);
|
||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, Object.class);
|
||||
objectResponse = responseEntity.getBody();
|
||||
log.info("objectResponse:{}", objectResponse);
|
||||
Map<String, Object> response = JSON.parseObject(objectResponse.toString());
|
||||
Map<String, Object> response = JsonUtil.objectToMap(objectResponse);
|
||||
webServiceResponse.setResult(response);
|
||||
} catch (Exception e) {
|
||||
log.info("Exception:{}", e.getMessage());
|
||||
|
||||
@@ -19,8 +19,7 @@ public class ParseContext {
|
||||
}
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
return Objects.nonNull(agent) && agent.containsDatasetTool()
|
||||
&& response.getSelectedParses().size() == 0;
|
||||
return Objects.nonNull(agent) && agent.containsDatasetTool();
|
||||
}
|
||||
|
||||
public boolean enableLLM() {
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
@@ -31,10 +24,8 @@ import java.util.Objects;
|
||||
* DataInterpretProcessor interprets query result to make it more readable to the users.
|
||||
*/
|
||||
public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
public static String tip = "AI 回答中...\r\n";
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static Map<Long, StringBuffer> resultCache = new HashMap<>();
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "DATA_INTERPRETER";
|
||||
private static final String INSTRUCTION = ""
|
||||
@@ -50,24 +41,12 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
.appModule(AppModule.CHAT).description("通过大模型对结果数据做提炼总结").enable(false).build());
|
||||
}
|
||||
|
||||
public static String getTextSummary(Long queryId) {
|
||||
if (resultCache.get(queryId) != null) {
|
||||
return resultCache.get(queryId).toString();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
public static Map<Long, StringBuffer> getResultCache() {
|
||||
return resultCache;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
Agent agent = executeContext.getAgent();
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
return Objects.nonNull(chatApp) && chatApp.isEnable()
|
||||
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
|
||||
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
|
||||
return Objects.nonNull(chatApp) && chatApp.isEnable();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -77,62 +56,18 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
|
||||
if (executeContext.getParseInfo().getProperties() != null
|
||||
&& executeContext.getParseInfo().getProperties().containsKey("CONTEXT")) {
|
||||
Map<String, Object> context = (Map<String, Object>) executeContext.getParseInfo()
|
||||
.getProperties().get("CONTEXT");
|
||||
if (context.get("queryText") != null && "".equals(context.get("queryText"))) {
|
||||
question = context.get("queryText").toString();
|
||||
}
|
||||
}
|
||||
variable.put("question", question);
|
||||
variable.put("question", executeContext.getRequest().getQueryText());
|
||||
variable.put("data", queryResult.getTextResult());
|
||||
|
||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||
if (executeContext.getRequest().isStreamingResult()) {
|
||||
StreamingChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatStreamingModel(chatApp.getChatModelConfig());
|
||||
final Long queryId = executeContext.getRequest().getQueryId();
|
||||
resultCache.put(queryId, new StringBuffer(tip));
|
||||
chatLanguageModel.generate(prompt.toUserMessage(),
|
||||
new StreamingResponseHandler<AiMessage>() {
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
resultCache.get(queryId).append(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
ChatQueryRepository chatQueryRepository =
|
||||
ContextUtils.getBean(ChatQueryRepository.class);
|
||||
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(queryId);
|
||||
JSONObject queryResult = JSON.parseObject(chatQueryDO.getQueryResult());
|
||||
queryResult.put("textSummary",
|
||||
resultCache.get(queryId).toString().substring(tip.length()));
|
||||
chatQueryDO.setQueryResult(queryResult.toJSONString());
|
||||
chatQueryRepository.updateChatQuery(chatQueryDO);
|
||||
resultCache.remove(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
error.printStackTrace();
|
||||
resultCache.remove(queryId);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String anwser = response.content().text();
|
||||
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}",
|
||||
prompt.text(), anwser);
|
||||
if (StringUtils.isNotBlank(anwser)) {
|
||||
queryResult.setTextSummary(anwser);
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String anwser = response.content().text();
|
||||
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
|
||||
anwser);
|
||||
if (StringUtils.isNotBlank(anwser)) {
|
||||
queryResult.setTextSummary(anwser);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,10 +22,11 @@ public class ChatController {
|
||||
private ChatManageService chatService;
|
||||
|
||||
@PostMapping("/save")
|
||||
public Long save(@RequestParam(value = "chatName") String chatName,
|
||||
public Boolean save(@RequestParam(value = "chatName") String chatName,
|
||||
@RequestParam(value = "agentId", required = false) Integer agentId,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
return chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
|
||||
chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
|
||||
return true;
|
||||
}
|
||||
|
||||
@GetMapping("/getAll")
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
@@ -51,14 +50,6 @@ public class ChatQueryController {
|
||||
return chatQueryService.execute(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("getExecuteSummary")
|
||||
public Object getExecuteSummary(@RequestBody ChatExecuteReq chatExecuteReq,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
QueryResult res = chatQueryService.getTextSummary(chatExecuteReq);
|
||||
return res;
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
public Object query(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
|
||||
@@ -4,12 +4,12 @@ import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryDeleteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -53,10 +53,8 @@ public class MemoryController {
|
||||
}
|
||||
|
||||
@PostMapping("batchDelete")
|
||||
public Boolean deleteMemory(@RequestBody ChatMemoryDeleteReq chatMemoryDeleteReq,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memoryService.batchDelete(chatMemoryDeleteReq, user);
|
||||
public Boolean batchDelete(@RequestBody MetaBatchReq metaBatchReq) {
|
||||
memoryService.batchDelete(metaBatchReq.getIds());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,8 +35,6 @@ public interface ChatManageService {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
ChatQueryDO getChatQueryDO(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
||||
|
||||
@@ -19,8 +19,6 @@ public interface ChatQueryService {
|
||||
|
||||
QueryResult execute(ChatExecuteReq chatExecuteReq) throws Exception;
|
||||
|
||||
QueryResult getTextSummary(ChatExecuteReq chatExecuteReq);
|
||||
|
||||
QueryResult parseAndExecute(ChatParseReq chatParseReq);
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryDeleteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
@@ -15,7 +14,7 @@ public interface MemoryService {
|
||||
|
||||
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
|
||||
|
||||
void batchDelete(ChatMemoryDeleteReq chatMemoryDeleteReq, User user);
|
||||
void batchDelete(List<Long> ids);
|
||||
|
||||
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -40,7 +39,6 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
@Lazy
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -123,11 +123,6 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
return chatQueryRepository.getChatQuery(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatQueryDO getChatQueryDO(Long queryId) {
|
||||
return chatQueryRepository.getChatQueryDO(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
|
||||
@@ -238,10 +233,6 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
@Override
|
||||
public SemanticParseInfo getParseInfo(Long questionId, int parseId) {
|
||||
ChatParseDO chatParseDO = chatQueryRepository.getParseInfo(questionId, parseId);
|
||||
if (chatParseDO == null) {
|
||||
return null;
|
||||
} else {
|
||||
return JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||
}
|
||||
return JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
@@ -10,10 +9,8 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.DataInterpretProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
@@ -21,11 +18,7 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.*;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
@@ -51,27 +44,15 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.expression.operators.relational.*;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@@ -85,7 +66,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
@Lazy
|
||||
private AgentService agentService;
|
||||
|
||||
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
@@ -128,8 +108,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
if (!parseContext.needFeedback()) {
|
||||
parseContext.getResponse().getParseTimeCost().setParseTime(System.currentTimeMillis()
|
||||
- parseContext.getResponse().getParseTimeCost().getParseStartTime());
|
||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||
}
|
||||
@@ -163,21 +141,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult getTextSummary(ChatExecuteReq chatExecuteReq) {
|
||||
String text = DataInterpretProcessor.getTextSummary(chatExecuteReq.getQueryId());
|
||||
if (StringUtils.isNotBlank(text)) {
|
||||
QueryResult res = new QueryResult();
|
||||
res.setTextSummary(text);
|
||||
res.setQueryId(chatExecuteReq.getQueryId());
|
||||
return res;
|
||||
} else {
|
||||
ChatQueryDO chatQueryDo = chatManageService.getChatQueryDO(chatExecuteReq.getQueryId());
|
||||
QueryResult res = JSON.parseObject(chatQueryDo.getQueryResult(), QueryResult.class);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
|
||||
ChatParseResp parseResp = parse(chatParseReq);
|
||||
|
||||
@@ -6,7 +6,6 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryDeleteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
@@ -27,7 +26,7 @@ import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -35,7 +34,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
public class MemoryServiceImpl implements MemoryService , CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private ChatMemoryRepository chatMemoryRepository;
|
||||
@@ -66,17 +65,12 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
ChatMemoryDO chatMemoryDO = chatMemoryRepository.getMemory(chatMemoryUpdateReq.getId());
|
||||
boolean hadEnabled =
|
||||
MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim());
|
||||
|
||||
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus())) {
|
||||
// Update the latest SQL/Schema to vector DB once memory is enabled
|
||||
chatMemoryDO.setS2sql(chatMemoryUpdateReq.getS2sql());
|
||||
chatMemoryDO.setDbSchema(chatMemoryUpdateReq.getDbSchema());
|
||||
if (MemoryStatus.ENABLED.equals(chatMemoryUpdateReq.getStatus()) && !hadEnabled) {
|
||||
enableMemory(chatMemoryDO);
|
||||
} else if ((MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus())
|
||||
|| MemoryStatus.PENDING.equals(chatMemoryUpdateReq.getStatus())) && hadEnabled) {
|
||||
// Remove from vector DB when transitioning: launched→disabled OR enabled→pending
|
||||
} else if (MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus()) && hadEnabled) {
|
||||
disableMemory(chatMemoryDO);
|
||||
}
|
||||
|
||||
LambdaUpdateWrapper<ChatMemoryDO> updateWrapper = new LambdaUpdateWrapper<>();
|
||||
updateWrapper.eq(ChatMemoryDO::getId, chatMemoryDO.getId());
|
||||
if (Objects.nonNull(chatMemoryUpdateReq.getStatus())) {
|
||||
@@ -97,12 +91,6 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
updateWrapper.set(ChatMemoryDO::getHumanReviewCmt,
|
||||
chatMemoryUpdateReq.getHumanReviewCmt());
|
||||
}
|
||||
if (Objects.nonNull(chatMemoryUpdateReq.getDbSchema())) {
|
||||
updateWrapper.set(ChatMemoryDO::getDbSchema, chatMemoryUpdateReq.getDbSchema());
|
||||
}
|
||||
if (Objects.nonNull(chatMemoryUpdateReq.getS2sql())) {
|
||||
updateWrapper.set(ChatMemoryDO::getS2sql, chatMemoryUpdateReq.getS2sql());
|
||||
}
|
||||
updateWrapper.set(ChatMemoryDO::getUpdatedAt, new Date());
|
||||
updateWrapper.set(ChatMemoryDO::getUpdatedBy, user.getName());
|
||||
|
||||
@@ -110,22 +98,7 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void batchDelete(ChatMemoryDeleteReq chatMemoryDeleteReq, User user) {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
if (!CollectionUtils.isEmpty(chatMemoryDeleteReq.getIds())) {
|
||||
queryWrapper.lambda().in(ChatMemoryDO::getId, chatMemoryDeleteReq.getIds());
|
||||
}
|
||||
if (chatMemoryDeleteReq.getAgentId() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryDeleteReq.getAgentId());
|
||||
}
|
||||
List<ChatMemoryDO> chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper);
|
||||
List<Long> ids = new ArrayList<>();
|
||||
chatMemoryDOS.forEach(chatMemoryDO -> {
|
||||
if (MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim())) {
|
||||
disableMemory(chatMemoryDO);
|
||||
}
|
||||
ids.add(chatMemoryDO.getId());
|
||||
});
|
||||
public void batchDelete(List<Long> ids) {
|
||||
chatMemoryRepository.batchDelete(ids);
|
||||
}
|
||||
|
||||
@@ -222,14 +195,12 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
public void run(String... args) { // 优化,启动时检查,向量数据,将记忆放到向量数据库
|
||||
loadSysExemplars();
|
||||
}
|
||||
|
||||
public void loadSysExemplars() {
|
||||
try {
|
||||
List<ChatMemory> memories = this
|
||||
.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
|
||||
for (ChatMemory memory : memories) {
|
||||
exemplarService.storeExemplar(
|
||||
embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
List<ChatMemory> memories =
|
||||
this.getMemories(ChatMemoryFilter.builder().status(MemoryStatus.ENABLED).build());
|
||||
for(ChatMemory memory:memories){
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
Text2SQLExemplar.builder().question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
|
||||
.sql(memory.getS2sql()).build());
|
||||
|
||||
@@ -108,7 +108,6 @@ public class PluginServiceImpl implements PluginService {
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getCreatedBy())) {
|
||||
queryWrapper.lambda().eq(PluginDO::getCreatedBy, pluginQueryReq.getCreatedBy());
|
||||
}
|
||||
queryWrapper.orderByAsc("name");
|
||||
List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper);
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
|
||||
pluginDOS = pluginDOS.stream()
|
||||
|
||||
@@ -21,10 +21,7 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-validation</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
@@ -36,7 +33,7 @@
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents.client5</groupId>
|
||||
<artifactId>httpclient5</artifactId>
|
||||
<version>${httpclient5.version}</version>
|
||||
<version>${httpclient5.version}</version> <!-- 请确认使用最新稳定版本 -->
|
||||
</dependency>
|
||||
<!-- <dependency>-->
|
||||
<!-- <groupId>org.apache.httpcomponents</groupId>-->
|
||||
@@ -185,6 +182,10 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-pgvector</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||
@@ -197,6 +198,34 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-qianfan</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-zhipu-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-chatglm</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-ollama</artifactId>
|
||||
@@ -208,6 +237,11 @@
|
||||
<version>${hanlp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
|
||||
@@ -21,8 +21,7 @@ public class LoadRemoveService {
|
||||
List<String> resultList = new ArrayList<>(value);
|
||||
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
||||
resultList.removeIf(nature -> {
|
||||
if (Objects.isNull(nature) || !nature.startsWith("_")) { // 系统的字典是以 _ 开头的,
|
||||
// 过滤因引用外部字典导致的异常
|
||||
if (Objects.isNull(nature)) {
|
||||
return false;
|
||||
}
|
||||
Long id = getId(nature);
|
||||
|
||||
@@ -77,6 +77,11 @@ public class SemanticSqlConformance implements SqlConformance {
|
||||
return SqlConformanceEnum.BIG_QUERY.isMinusAllowed();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isRegexReplaceCaptureGroupDollarIndexed() {
|
||||
return SqlConformanceEnum.BIG_QUERY.isRegexReplaceCaptureGroupDollarIndexed();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isApplyAllowed() {
|
||||
return SqlConformanceEnum.BIG_QUERY.isApplyAllowed();
|
||||
|
||||
@@ -4,10 +4,14 @@ import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -66,31 +70,52 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER);
|
||||
OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL));
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO,
|
||||
DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO,
|
||||
ZhipuModelFactory.PROVIDER, DEMO));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER),
|
||||
OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
||||
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER,
|
||||
AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||
@@ -101,7 +126,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,62 +46,6 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
replaceComparisonExpression(expr);
|
||||
}
|
||||
|
||||
public void visit(LikeExpression expr) {
|
||||
Expression leftExpression = expr.getLeftExpression();
|
||||
Expression rightExpression = expr.getRightExpression();
|
||||
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
|
||||
return;
|
||||
}
|
||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
Column column = (Column) leftExpression;
|
||||
String columnName = column.getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return;
|
||||
}
|
||||
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
|
||||
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
if (rightExpression instanceof StringValue) {
|
||||
StringValue rightStringValue = (StringValue) rightExpression;
|
||||
String value = rightStringValue.getValue();
|
||||
|
||||
// 使用split处理方式,按通配符分割字符串,对每个片段进行转换
|
||||
String[] parts = value.split("%", -1);
|
||||
boolean changed = false;
|
||||
|
||||
// 处理每个部分
|
||||
for (int i = 0; i < parts.length; i++) {
|
||||
if (!parts[i].isEmpty()) {
|
||||
String replaceValue = getReplaceValue(valueMap, parts[i]);
|
||||
if (StringUtils.isNotEmpty(replaceValue) && !parts[i].equals(replaceValue)) {
|
||||
parts[i] = replaceValue;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有任何部分发生变化,则重新构建字符串
|
||||
if (changed) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < parts.length; i++) {
|
||||
sb.append(parts[i]);
|
||||
// 除了最后一个部分,其他部分后面都需要加上"%"
|
||||
if (i < parts.length - 1) {
|
||||
sb.append("%");
|
||||
}
|
||||
}
|
||||
rightStringValue.setValue(sb.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void visit(InExpression inExpression) {
|
||||
if (!(inExpression.getLeftExpression() instanceof Column)) {
|
||||
return;
|
||||
|
||||
@@ -16,21 +16,16 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
@Slf4j
|
||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
private Set<String> fieldNames;
|
||||
private Map<String, String> fieldNameMap = new HashMap<>();
|
||||
|
||||
private static Set<String> HAVING_AGG_TYPES = Set.of("SUM", "AVG", "MAX", "MIN", "COUNT");
|
||||
|
||||
public FiledFilterReplaceVisitor(Map<String, String> fieldNameMap) {
|
||||
this.fieldNameMap = fieldNameMap;
|
||||
this.fieldNames = fieldNameMap.keySet();
|
||||
}
|
||||
|
||||
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
|
||||
this.fieldNames = fieldNames;
|
||||
@@ -87,22 +82,7 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
if (leftExpression instanceof Column) {
|
||||
Column leftColumn = (Column) leftExpression;
|
||||
String agg = fieldNameMap.get(leftColumn.getColumnName());
|
||||
if (agg != null && HAVING_AGG_TYPES.contains(agg.toUpperCase())) {
|
||||
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||
if (Objects.nonNull(expression)) {
|
||||
result.add(expression);
|
||||
return result;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Function leftFunction = (Function) leftExpression;
|
||||
@@ -122,24 +102,14 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
return null;
|
||||
}
|
||||
|
||||
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||
if (Objects.nonNull(expression)) {
|
||||
result.add(expression);
|
||||
return result;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Expression parseCondExpression(ComparisonOperator comparisonOperator, String condExpr) {
|
||||
try {
|
||||
String comparisonOperatorStr = comparisonOperator.toString();
|
||||
ComparisonOperator parsedExpression =
|
||||
(ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
||||
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
||||
return CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr);
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
|
||||
return result;
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
|
||||
@@ -309,7 +309,7 @@ public class SqlAddHelper {
|
||||
}
|
||||
}
|
||||
|
||||
public static String addHaving(String sql, Map<String, String> fieldNames) {
|
||||
public static String addHaving(String sql, Set<String> fieldNames) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
|
||||
@@ -118,26 +118,22 @@ public class SqlReplaceHelper {
|
||||
}
|
||||
|
||||
public static void getFromSelect(FromItem fromItem, List<PlainSelect> plainSelectList) {
|
||||
if (!(fromItem instanceof ParenthesedSelect parenthesedSelect)) {
|
||||
if (!(fromItem instanceof ParenthesedSelect)) {
|
||||
return;
|
||||
}
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
|
||||
Select select = parenthesedSelect.getSelect();
|
||||
if (select instanceof PlainSelect plainSelect) {
|
||||
if (select instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) select;
|
||||
plainSelectList.add(plainSelect);
|
||||
getFromSelect(plainSelect.getFromItem(), plainSelectList);
|
||||
} else if (select instanceof SetOperationList setOperationList) {
|
||||
} else if (select instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) select;
|
||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||
setOperationList.getSelects().forEach(subSelectBody -> {
|
||||
if (subSelectBody instanceof PlainSelect subPlainSelect) {
|
||||
plainSelectList.add(subPlainSelect);
|
||||
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
|
||||
} else if (subSelectBody instanceof ParenthesedSelect subParenthesedSelect) {
|
||||
Select innerSelect = subParenthesedSelect.getSelect();
|
||||
if (innerSelect instanceof PlainSelect innerPlainSelect) {
|
||||
plainSelectList.add(innerPlainSelect);
|
||||
getFromSelect(innerPlainSelect.getFromItem(), plainSelectList);
|
||||
}
|
||||
}
|
||||
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
plainSelectList.add(subPlainSelect);
|
||||
getFromSelect(subPlainSelect.getFromItem(), plainSelectList);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -192,13 +188,8 @@ public class SqlReplaceHelper {
|
||||
SetOperationList setOperationList = (SetOperationList) select;
|
||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||
setOperationList.getSelects().forEach(subSelectBody -> {
|
||||
if (subSelectBody instanceof PlainSelect) {
|
||||
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
|
||||
} else if (subSelectBody instanceof ParenthesedSelect) {
|
||||
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace,
|
||||
((ParenthesedSelect) subSelectBody).getPlainSelect());
|
||||
}
|
||||
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
|
||||
replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -736,7 +727,7 @@ public class SqlReplaceHelper {
|
||||
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList);
|
||||
for (PlainSelect plainSelect : plainSelects) {
|
||||
if (Objects.nonNull(plainSelect.getFromItem())) {
|
||||
Table table = SqlSelectHelper.getTable(plainSelect.getFromItem());
|
||||
Table table = (Table) plainSelect.getFromItem();
|
||||
if (table.getName().equals(tableName)) {
|
||||
replacePlainSelectByExpr(plainSelect, replace);
|
||||
if (SqlSelectHelper.hasAggregateFunction(plainSelect)) {
|
||||
|
||||
@@ -723,44 +723,6 @@ public class SqlSelectHelper {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Table getTable(FromItem fromItem) {
|
||||
Table table = null;
|
||||
if (fromItem instanceof Table) {
|
||||
table = (Table) fromItem;
|
||||
} else if (fromItem instanceof ParenthesedSelect) {
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
|
||||
if (parenthesedSelect.getSelect() instanceof PlainSelect) {
|
||||
PlainSelect subSelect = (PlainSelect) parenthesedSelect.getSelect();
|
||||
table = getTable(subSelect.getSelectBody());
|
||||
} else if (parenthesedSelect.getSelect() instanceof SetOperationList) {
|
||||
table = getTable(parenthesedSelect.getSelect());
|
||||
}
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
public static Table getTable(Select select) {
|
||||
if (select == null) {
|
||||
return null;
|
||||
}
|
||||
List<PlainSelect> plainSelectList = getWithItem(select);
|
||||
if (!CollectionUtils.isEmpty(plainSelectList)) {
|
||||
List<PlainSelect> selectList = new ArrayList<>(plainSelectList);
|
||||
Table table = getTable(selectList.get(0));
|
||||
return table;
|
||||
}
|
||||
if (select instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) select;
|
||||
return getTable(plainSelect.getFromItem());
|
||||
} else if (select instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) select;
|
||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||
return getTable(setOperationList.getSelects().get(0));
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static String getDbTableName(String sql) {
|
||||
Table table = getTable(sql);
|
||||
return table.getFullyQualifiedName();
|
||||
|
||||
@@ -28,8 +28,6 @@ public class ChatModelConfig implements Serializable {
|
||||
private Boolean logRequests = false;
|
||||
private Boolean logResponses = false;
|
||||
private Boolean enableSearch = false;
|
||||
private Boolean jsonFormat = false;
|
||||
private String jsonFormatType = "json_schema";
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||
|
||||
@@ -2,7 +2,15 @@ package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import dev.langchain4j.provider.*;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.DifyModelFactory;
|
||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -13,7 +21,7 @@ public class ChatModelParameters {
|
||||
|
||||
public static final Parameter CHAT_MODEL_PROVIDER =
|
||||
new Parameter("provider", ModelProvider.DEMO_CHAT_MODEL.getProvider(), "接口协议", "",
|
||||
"list", MODULE_NAME, getCandidateProviders());
|
||||
"list", MODULE_NAME, getCandidateValues());
|
||||
|
||||
public static final Parameter CHAT_MODEL_BASE_URL =
|
||||
new Parameter("baseUrl", ModelProvider.DEMO_CHAT_MODEL.getBaseUrl(), "BaseUrl", "",
|
||||
@@ -29,6 +37,15 @@ public class ChatModelParameters {
|
||||
public static final Parameter CHAT_MODEL_API_VERSION = new Parameter("apiVersion", "2024-02-01",
|
||||
"ApiVersion", "", "string", MODULE_NAME, null, getApiVersionDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
|
||||
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("secretKey", "demo",
|
||||
"SecretKey", "", "password", MODULE_NAME, null, getSecretKeyDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = new Parameter("enableSearch", "false",
|
||||
"是否启用搜索增强功能,设为false表示不启用", "", "bool", MODULE_NAME, null, getEnableSearchDependency());
|
||||
|
||||
public static final Parameter CHAT_MODEL_TEMPERATURE =
|
||||
new Parameter("temperature", "0.0", "Temperature", "", "slider", MODULE_NAME);
|
||||
|
||||
@@ -36,27 +53,42 @@ public class ChatModelParameters {
|
||||
new Parameter("timeOut", "60", "超时时间(秒)", "", "number", MODULE_NAME);
|
||||
|
||||
public static List<Parameter> getParameters() {
|
||||
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY,
|
||||
CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION, CHAT_MODEL_TEMPERATURE,
|
||||
CHAT_MODEL_TIMEOUT);
|
||||
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION,
|
||||
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
|
||||
}
|
||||
|
||||
private static List<String> getCandidateProviders() {
|
||||
private static List<String> getCandidateValues() {
|
||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||
DifyModelFactory.PROVIDER);
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DifyModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateProviders(),
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_BASE_URL));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, DifyModelFactory.PROVIDER),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
DifyModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DifyModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
}
|
||||
@@ -68,28 +100,33 @@ public class ChatModelParameters {
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateProviders(),
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME,
|
||||
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap
|
||||
.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME));
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
|
||||
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "false"));
|
||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
|
||||
QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import lombok.Getter;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
public class DataEvent extends ApplicationEvent {
|
||||
|
||||
private final List<DataItem> dataItems;
|
||||
private List<DataItem> dataItems;
|
||||
|
||||
private final EventType eventType;
|
||||
private EventType eventType;
|
||||
|
||||
private final String userName;
|
||||
|
||||
public DataEvent(Object source, List<DataItem> dataItems, EventType eventType,
|
||||
String userName) {
|
||||
public DataEvent(Object source, List<DataItem> dataItems, EventType eventType) {
|
||||
super(source);
|
||||
this.dataItems = dataItems;
|
||||
this.eventType = eventType;
|
||||
this.userName = userName;
|
||||
}
|
||||
|
||||
public List<DataItem> getDataItems() {
|
||||
return dataItems;
|
||||
}
|
||||
|
||||
public EventType getEventType() {
|
||||
return eventType;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,4 @@ public class DimensionConstants {
|
||||
public static final String DIMENSION_TIME_FORMAT = "time_format";
|
||||
|
||||
public static final String DIMENSION_TYPE = "dimension_type";
|
||||
|
||||
public static final String DIMENSION_DATA_TYPE = "dimension_data_type";
|
||||
}
|
||||
|
||||
@@ -22,6 +22,4 @@ public class Text2SQLExemplar implements Serializable {
|
||||
private String dbSchema;
|
||||
|
||||
private String sql;
|
||||
|
||||
protected double similarity; // 传递相似度,可以作为样本筛选的依据
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import lombok.NoArgsConstructor;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.sql.Timestamp;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@@ -23,28 +22,26 @@ public class User implements Serializable {
|
||||
|
||||
private Integer isAdmin;
|
||||
|
||||
private Timestamp lastLogin;
|
||||
|
||||
public static User get(Long id, String name, String displayName, String email,
|
||||
Integer isAdmin) {
|
||||
return new User(id, name, displayName, email, isAdmin, null);
|
||||
return new User(id, name, displayName, email, isAdmin);
|
||||
}
|
||||
|
||||
public static User get(Long id, String name) {
|
||||
return new User(id, name, name, name, 0, null);
|
||||
return new User(id, name, name, name, 0);
|
||||
}
|
||||
|
||||
public static User getDefaultUser() {
|
||||
return new User(1L, "admin", "admin", "admin@email", 1, null);
|
||||
return new User(1L, "admin", "admin", "admin@email", 1);
|
||||
}
|
||||
|
||||
public static User getVisitUser() {
|
||||
return new User(1L, "visit", "visit", "visit@email", 0, null);
|
||||
return new User(1L, "visit", "visit", "visit@email", 0);
|
||||
}
|
||||
|
||||
public static User getAppUser(int appId) {
|
||||
String name = String.format("app_%s", appId);
|
||||
return new User(1L, name, name, "", 1, null);
|
||||
return new User(1L, name, name, "", 1);
|
||||
}
|
||||
|
||||
public String getDisplayName() {
|
||||
|
||||
@@ -13,8 +13,7 @@ public enum EngineType {
|
||||
STARROCKS(10, "STARROCKS"),
|
||||
KYUUBI(11, "KYUUBI"),
|
||||
PRESTO(12, "PRESTO"),
|
||||
TRINO(13, "TRINO"),
|
||||
ORACLE(14, "ORACLE");
|
||||
TRINO(13, "TRINO"),;
|
||||
|
||||
private Integer code;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum TypeEnums {
|
||||
METRIC, DIMENSION, VALUE, TAG, DOMAIN, DATASET, MODEL, UNKNOWN
|
||||
METRIC, DIMENSION, TAG, DOMAIN, DATASET, MODEL, UNKNOWN
|
||||
}
|
||||
|
||||
@@ -15,5 +15,5 @@ public interface ChatModelService {
|
||||
|
||||
ChatModel updateChatModel(ChatModel chatModel, User user);
|
||||
|
||||
void deleteChatModel(Integer id, User user);
|
||||
void deleteChatModel(Integer id);
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -32,7 +31,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}).sorted(Comparator.comparingLong(ChatModel::getId)).collect(Collectors.toList());
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -79,12 +78,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteChatModel(Integer id, User user) {
|
||||
ChatModel chatModel = getChatModel(id);
|
||||
if (!checkAdminPermission(user, chatModel)) {
|
||||
throw new RuntimeException("没有权限删除该大模型");
|
||||
}
|
||||
|
||||
public void deleteChatModel(Integer id) {
|
||||
removeById(id);
|
||||
}
|
||||
|
||||
@@ -108,13 +102,4 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
chatModelDO.setConfig(JsonUtil.toString(chatModel.getConfig()));
|
||||
return chatModelDO;
|
||||
}
|
||||
|
||||
private boolean checkAdminPermission(User user, ChatModel chatModel) {
|
||||
String admin = chatModel.getAdmin();
|
||||
if (user.isSuperAdmin()) {
|
||||
return true;
|
||||
}
|
||||
return admin != null && admin.equals(user.getName())
|
||||
|| chatModel.getCreatedBy().equals(user.getName());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,10 +49,10 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
try {
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
MetadataFilterBuilder filterBuilder =
|
||||
new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID);
|
||||
Filter filter = filterBuilder.isEqualTo(TextSegmentConvert.getQueryId(query));
|
||||
embeddingStore.removeAll(filter);
|
||||
boolean existSegment = existSegment(embeddingStore, query, embedding);
|
||||
if (existSegment) {
|
||||
continue;
|
||||
}
|
||||
embeddingStore.add(embedding, query);
|
||||
cache.put(TextSegmentConvert.getQueryId(query), true);
|
||||
} catch (Exception e) {
|
||||
@@ -62,14 +62,14 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSegment(String collectionName, EmbeddingStore embeddingStore,
|
||||
TextSegment query, Embedding embedding) {
|
||||
private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query,
|
||||
Embedding embedding) {
|
||||
String queryId = TextSegmentConvert.getQueryId(query);
|
||||
if (queryId == null) {
|
||||
return false;
|
||||
}
|
||||
// Check cache first
|
||||
Boolean cachedResult = cache.getIfPresent(collectionName + queryId);
|
||||
Boolean cachedResult = cache.getIfPresent(queryId);
|
||||
if (cachedResult != null) {
|
||||
return cachedResult;
|
||||
}
|
||||
@@ -82,7 +82,7 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||
boolean exists = CollectionUtils.isNotEmpty(relevant);
|
||||
cache.put(collectionName + queryId, exists);
|
||||
cache.put(queryId, exists);
|
||||
return exists;
|
||||
}
|
||||
|
||||
|
||||
@@ -72,10 +72,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||
results.forEach(ret -> {
|
||||
ret.getRetrieval().forEach(r -> {
|
||||
Text2SQLExemplar tmp = // 传递相似度,可以作为样本筛选的依据
|
||||
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
|
||||
tmp.setSimilarity(r.getSimilarity());
|
||||
exemplars.add(tmp);
|
||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, SystemConfigDO>
|
||||
@@ -39,8 +38,8 @@ public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, Sys
|
||||
return systemConfigDb;
|
||||
}
|
||||
|
||||
private SystemConfig getSystemConfigFromDB() { // 加上id ,如果有多条记录,会出错
|
||||
List<SystemConfigDO> list = this.lambdaQuery().eq(SystemConfigDO::getId, 1).list();
|
||||
private SystemConfig getSystemConfigFromDB() {
|
||||
List<SystemConfigDO> list = list();
|
||||
if (CollectionUtils.isEmpty(list)) {
|
||||
SystemConfig systemConfig = new SystemConfig();
|
||||
systemConfig.setId(1);
|
||||
|
||||
@@ -242,8 +242,10 @@ public class DateModeUtils {
|
||||
return String.format("%s >= '%s' and %s <= '%s'", dateField,
|
||||
dateInfo.getStartDate(), dateField, dateInfo.getEndDate());
|
||||
}
|
||||
LocalDate endData = DateUtils.parseDate(dateInfo.getEndDate());
|
||||
LocalDate startData = DateUtils.parseDate(dateInfo.getStartDate());
|
||||
LocalDate endData =
|
||||
LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT));
|
||||
LocalDate startData = LocalDate.parse(dateInfo.getStartDate(),
|
||||
DateTimeFormatter.ofPattern(DAY_FORMAT));
|
||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT);
|
||||
return String.format("%s >= '%s' and %s <= '%s'", dateField,
|
||||
startData.format(formatter), dateField, endData.format(formatter));
|
||||
@@ -318,7 +320,7 @@ public class DateModeUtils {
|
||||
}
|
||||
|
||||
public String getDateWhereStr(DateConf dateInfo, ItemDateResp dateDate) {
|
||||
if (Objects.isNull(dateInfo) || Objects.isNull(dateInfo.getDateField())) {
|
||||
if (Objects.isNull(dateInfo)) {
|
||||
return "";
|
||||
}
|
||||
String dateStr = "";
|
||||
|
||||
@@ -75,7 +75,7 @@ public class DateUtils {
|
||||
}
|
||||
|
||||
public static String getBeforeDate(String currentDate, DatePeriodEnum datePeriodEnum) {
|
||||
LocalDate specifiedDate = parseDate(currentDate);
|
||||
LocalDate specifiedDate = LocalDate.parse(currentDate, DEFAULT_DATE_FORMATTER2);
|
||||
LocalDate startDate;
|
||||
switch (datePeriodEnum) {
|
||||
case MONTH:
|
||||
@@ -93,7 +93,7 @@ public class DateUtils {
|
||||
|
||||
public static String getBeforeDate(String currentDate, int intervalDay,
|
||||
DatePeriodEnum datePeriodEnum) {
|
||||
LocalDate specifiedDate = parseDate(currentDate);
|
||||
LocalDate specifiedDate = LocalDate.parse(currentDate, DEFAULT_DATE_FORMATTER2);
|
||||
LocalDate result = null;
|
||||
switch (datePeriodEnum) {
|
||||
case DAY:
|
||||
@@ -161,25 +161,11 @@ public class DateUtils {
|
||||
return !timeString.equals("00:00:00");
|
||||
}
|
||||
|
||||
public static LocalDate parseDate(String timeString) {
|
||||
DateTimeFormatter[] dateFormatters =
|
||||
{DateTimeFormatter.ofPattern("yyyyMMdd"), DateTimeFormatter.ofPattern("yyyy-MM-dd"),
|
||||
DateTimeFormatter.ofPattern("yyyy/MM/dd"),
|
||||
DateTimeFormatter.ofPattern("yyyy-MM")};
|
||||
for (DateTimeFormatter formatter : dateFormatters) {
|
||||
try {
|
||||
return LocalDate.parse(timeString, formatter);
|
||||
} catch (DateTimeParseException ignored) {
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static List<String> getDateList(String startDateStr, String endDateStr,
|
||||
DatePeriodEnum period) {
|
||||
try {
|
||||
LocalDate startDate = parseDate(startDateStr);
|
||||
LocalDate endDate = parseDate(endDateStr);
|
||||
LocalDate startDate = LocalDate.parse(startDateStr);
|
||||
LocalDate endDate = LocalDate.parse(endDateStr);
|
||||
List<String> datesInRange = new ArrayList<>();
|
||||
LocalDate currentDate = startDate;
|
||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM");
|
||||
@@ -203,7 +189,7 @@ public class DateUtils {
|
||||
}
|
||||
|
||||
public static boolean isAnyDateString(String value) {
|
||||
List<String> formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd", "yyyyMMdd");
|
||||
List<String> formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd");
|
||||
return isAnyDateString(value, formats);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class ChatModelProperties {
|
||||
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
String modelName;
|
||||
Double topP;
|
||||
Integer topK;
|
||||
Boolean enableSearch;
|
||||
Integer seed;
|
||||
Float repetitionPenalty;
|
||||
Float temperature;
|
||||
List<String> stops;
|
||||
Integer maxTokens;
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
import dev.langchain4j.model.dashscope.QwenLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenStreamingLanguageModel;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(Properties.class)
|
||||
public class DashscopeAutoConfig {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
QwenChatModel qwenChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
||||
.topK(chatModelProperties.getTopK())
|
||||
.enableSearch(chatModelProperties.getEnableSearch())
|
||||
.seed(chatModelProperties.getSeed())
|
||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
||||
.topK(chatModelProperties.getTopK())
|
||||
.enableSearch(chatModelProperties.getEnableSearch())
|
||||
.seed(chatModelProperties.getSeed())
|
||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
||||
QwenLanguageModel qwenLanguageModel(Properties properties) {
|
||||
ChatModelProperties languageModel = properties.getLanguageModel();
|
||||
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
||||
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
|
||||
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
|
||||
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
|
||||
.modelName(embeddingModelProperties.getModelName()).build();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingModelProperties {
|
||||
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
||||
public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.dashscope";
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
package dev.langchain4j.inmemory.spring;
|
||||
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.bgesmallzh.BgeSmallZhEmbeddingModel;
|
||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -9,7 +9,6 @@ import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -33,7 +32,6 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
private final Double temperature;
|
||||
private final Long timeOut;
|
||||
|
||||
@Setter
|
||||
private String userName;
|
||||
|
||||
@Builder
|
||||
@@ -56,7 +54,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
@Override
|
||||
public String generate(String message) {
|
||||
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
||||
return difyResult.getAnswer();
|
||||
return difyResult.getAnswer().toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -69,7 +67,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
ensureNotEmpty(messages, "messages");
|
||||
DifyResult difyResult =
|
||||
this.difyClient.generate(messages.get(0).toString(), this.getUserName());
|
||||
this.difyClient.generate(messages.get(0).text(), this.getUserName());
|
||||
System.out.println(difyResult.toString());
|
||||
|
||||
if (!isNullOrEmpty(toolSpecifications)) {
|
||||
@@ -86,8 +84,12 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
}
|
||||
|
||||
public void setUserName(String userName) {
|
||||
this.userName = userName;
|
||||
}
|
||||
|
||||
public String getUserName() {
|
||||
return null == userName ? "admin" : userName;
|
||||
return null == userName ? "zhaodongsheng" : userName;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package dev.langchain4j.model.embedding;
|
||||
|
||||
import dev.langchain4j.model.embedding.onnx.AbstractInProcessEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.OnnxBertBiEncoder;
|
||||
import dev.langchain4j.model.embedding.onnx.PoolingMode;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.IOException;
|
||||
@@ -12,7 +9,6 @@ import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
/**
|
||||
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
|
||||
@@ -29,7 +25,6 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
private static volatile String cachedVocabularyPath;
|
||||
|
||||
public S2OnnxEmbeddingModel(String pathToModel, String vocabularyPath) {
|
||||
super(Executors.newSingleThreadExecutor());
|
||||
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||
synchronized (S2OnnxEmbeddingModel.class) {
|
||||
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||
@@ -66,8 +61,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
|
||||
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||
try {
|
||||
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel),
|
||||
vocabularyFile.openStream(), PoolingMode.MEAN);
|
||||
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile,
|
||||
PoolingMode.MEAN);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
import static java.util.Collections.emptyList;
|
||||
@@ -66,6 +66,7 @@ import static java.util.Collections.singletonList;
|
||||
@Slf4j
|
||||
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
|
||||
public static final String ZHIPU = "bigmodel";
|
||||
private final OpenAiClient client;
|
||||
private final String baseUrl;
|
||||
private final String modelName;
|
||||
@@ -110,7 +111,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
.connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
||||
.logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
||||
.customHeaders(customHeaders).build();
|
||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO.name());
|
||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
||||
this.apiVersion = apiVersion;
|
||||
this.temperature = getOrDefault(temperature, 0.7);
|
||||
this.topP = topP;
|
||||
@@ -129,7 +130,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
this.strictTools = getOrDefault(strictTools, false);
|
||||
this.parallelToolCalls = parallelToolCalls;
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.tokenizer = getOrDefault(tokenizer, () -> new OpenAiTokenizer(this.modelName));
|
||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
}
|
||||
|
||||
@@ -191,7 +192,9 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
.responseFormat(responseFormat).seed(seed).user(user)
|
||||
.parallelToolCalls(parallelToolCalls);
|
||||
|
||||
requestBuilder.temperature(temperature);
|
||||
if (!(baseUrl.contains(ZHIPU))) {
|
||||
requestBuilder.temperature(temperature);
|
||||
}
|
||||
|
||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package dev.langchain4j.model.zhipu;
|
||||
|
||||
public enum ChatCompletionModel {
|
||||
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
|
||||
|
||||
private final String value;
|
||||
|
||||
ChatCompletionModel(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return this.value;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package dev.langchain4j.model.zhipu;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
|
||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
|
||||
import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.RetryUtils.withRetry;
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom;
|
||||
import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
/**
|
||||
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
|
||||
* glm-4. You can find description of parameters
|
||||
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
|
||||
*/
|
||||
public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
|
||||
private final String baseUrl;
|
||||
private final Double temperature;
|
||||
private final Double topP;
|
||||
private final String model;
|
||||
private final Integer maxRetries;
|
||||
private final Integer maxToken;
|
||||
private final ZhipuAiClient client;
|
||||
|
||||
@Builder
|
||||
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
|
||||
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
|
||||
Boolean logResponses) {
|
||||
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
||||
this.temperature = getOrDefault(temperature, 0.7);
|
||||
this.topP = topP;
|
||||
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.maxToken = getOrDefault(maxToken, 512);
|
||||
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
|
||||
.logRequests(getOrDefault(logRequests, false))
|
||||
.logResponses(getOrDefault(logResponses, false)).build();
|
||||
}
|
||||
|
||||
public static ZhipuAiChatModelBuilder builder() {
|
||||
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
|
||||
ZhipuAiChatModelBuilderFactory.class)) {
|
||||
return factories.get();
|
||||
}
|
||||
return new ZhipuAiChatModelBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
return generate(messages, (ToolSpecification) null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
ensureNotEmpty(messages, "messages");
|
||||
|
||||
ChatCompletionRequest.Builder requestBuilder =
|
||||
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
|
||||
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
|
||||
|
||||
if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.tools(toTools(toolSpecifications));
|
||||
}
|
||||
|
||||
ChatCompletionResponse response =
|
||||
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
|
||||
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
|
||||
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
ToolSpecification toolSpecification) {
|
||||
return generate(messages,
|
||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
}
|
||||
|
||||
public static class ZhipuAiChatModelBuilder {
|
||||
public ZhipuAiChatModelBuilder() {}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "AZURE";
|
||||
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
|
||||
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
|
||||
.topP(modelConfig.getTopP())
|
||||
.timeout(Duration.ofSeconds(
|
||||
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
||||
.logRequestsAndResponses(
|
||||
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
AzureOpenAiEmbeddingModel.Builder builder =
|
||||
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.deploymentName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
||||
&& embeddingModelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DASHSCOPE";
|
||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
||||
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature() == null ? 0L
|
||||
: modelConfig.getTemperature().floatValue())
|
||||
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -6,8 +6,7 @@ import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dify.DifyAiChatModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -26,16 +25,10 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
|
||||
.modelName(modelConfig.getModelName()).timeOut(modelConfig.getTimeOut()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.onnx.bgesmallzh.BgeSmallZhEmbeddingModel;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -36,11 +35,6 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
|
||||
@@ -6,7 +6,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -28,11 +27,6 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
|
||||
|
||||
@@ -4,12 +4,9 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
|
||||
public interface ModelFactory {
|
||||
ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
|
||||
|
||||
OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig);
|
||||
|
||||
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
|
||||
}
|
||||
|
||||
@@ -5,9 +5,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -43,20 +41,6 @@ public class ModelProvider {
|
||||
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||
}
|
||||
|
||||
public static StreamingChatLanguageModel getChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||
modelConfig = DEMO_CHAT_MODEL;
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatStreamingModel(modelConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException(
|
||||
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel getEmbeddingModel() {
|
||||
return getEmbeddingModel(null);
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -29,11 +28,6 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
|
||||
@@ -6,7 +6,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -23,26 +22,10 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl()).modelName(modelConfig.getModelName())
|
||||
.apiKey(modelConfig.keyDecrypt()).apiVersion(modelConfig.getApiVersion())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses());
|
||||
if (modelConfig.getJsonFormat() != null && modelConfig.getJsonFormat()) {
|
||||
openAiChatModelBuilder.strictJsonSchema(true)
|
||||
.responseFormat(modelConfig.getJsonFormatType());
|
||||
}
|
||||
return openAiChatModelBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiStreamingChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "QIANFAN";
|
||||
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
|
||||
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";
|
||||
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
|
||||
public static final String DEFAULT_ENDPOINT = "llama_2_70b";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
|
||||
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.secretKey(embeddingModelConfig.getSecretKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ChatCompletionModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import static java.time.Duration.ofSeconds;
|
||||
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "ZHIPU";
|
||||
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
|
||||
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
|
||||
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
|
||||
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class ChatModelProperties {
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private Double temperature;
|
||||
private Integer maxRetries;
|
||||
private Double topP;
|
||||
private String modelName;
|
||||
private String endpoint;
|
||||
private String responseFormat;
|
||||
private Double penaltyScore;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingModelProperties {
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private Integer maxRetries;
|
||||
private String modelName;
|
||||
private String endpoint;
|
||||
private String user;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class LanguageModelProperties {
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private Double temperature;
|
||||
private Integer maxRetries;
|
||||
private Integer topK;
|
||||
private Double topP;
|
||||
private String modelName;
|
||||
private String endpoint;
|
||||
private Double penaltyScore;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
||||
public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.qianfan";
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanLanguageModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanStreamingChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanStreamingLanguageModel;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(Properties.class)
|
||||
public class QianfanAutoConfig {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
QianfanChatModel qianfanChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.secretKey(chatModelProperties.getSecretKey())
|
||||
.endpoint(chatModelProperties.getEndpoint())
|
||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.responseFormat(chatModelProperties.getResponseFormat())
|
||||
.maxRetries(chatModelProperties.getMaxRetries())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
|
||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.secretKey(chatModelProperties.getSecretKey())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.responseFormat(chatModelProperties.getResponseFormat())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
||||
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
|
||||
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
|
||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.apiKey(languageModelProperties.getApiKey())
|
||||
.secretKey(languageModelProperties.getSecretKey())
|
||||
.modelName(languageModelProperties.getModelName())
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
||||
QianfanStreamingLanguageModel qianfanStreamingLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
|
||||
return QianfanStreamingLanguageModel.builder()
|
||||
.endpoint(languageModelProperties.getEndpoint())
|
||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.apiKey(languageModelProperties.getApiKey())
|
||||
.secretKey(languageModelProperties.getSecretKey())
|
||||
.modelName(languageModelProperties.getModelName())
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
.endpoint(embeddingModelProperties.getEndpoint())
|
||||
.apiKey(embeddingModelProperties.getApiKey())
|
||||
.secretKey(embeddingModelProperties.getSecretKey())
|
||||
.modelName(embeddingModelProperties.getModelName())
|
||||
.user(embeddingModelProperties.getUser())
|
||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
||||
.logRequests(embeddingModelProperties.getLogRequests())
|
||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,6 @@ public class TextSegmentConvert {
|
||||
if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) {
|
||||
return null;
|
||||
}
|
||||
return textSegment.metadata().getString(QUERY_ID);
|
||||
return textSegment.metadata().get(QUERY_ID);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,6 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
private final ConsistencyLevelEnum consistencyLevel;
|
||||
private final boolean retrieveEmbeddingsOnSearch;
|
||||
private final boolean autoFlushOnInsert;
|
||||
private final FieldDefinition fieldDefinition;
|
||||
|
||||
public MilvusEmbeddingStore(String host, Integer port, String collectionName, Integer dimension,
|
||||
IndexType indexType, MetricType metricType, String uri, String token, String username,
|
||||
@@ -79,15 +78,11 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
|
||||
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
|
||||
|
||||
// Define the field structure for the collection
|
||||
this.fieldDefinition = new FieldDefinition(ID_FIELD_NAME, TEXT_FIELD_NAME,
|
||||
METADATA_FIELD_NAME, VECTOR_FIELD_NAME);
|
||||
|
||||
if (!hasCollection(this.milvusClient, this.collectionName)) {
|
||||
createCollection(this.milvusClient, this.collectionName, fieldDefinition,
|
||||
createCollection(this.milvusClient, this.collectionName,
|
||||
ensureNotNull(dimension, "dimension"));
|
||||
createIndex(this.milvusClient, this.collectionName, VECTOR_FIELD_NAME,
|
||||
getOrDefault(indexType, FLAT), this.metricType);
|
||||
createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT),
|
||||
this.metricType);
|
||||
}
|
||||
|
||||
loadCollectionInMemory(this.milvusClient, collectionName);
|
||||
@@ -133,7 +128,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
public EmbeddingSearchResult<TextSegment> search(
|
||||
EmbeddingSearchRequest embeddingSearchRequest) {
|
||||
|
||||
SearchParam searchParam = buildSearchRequest(collectionName, fieldDefinition,
|
||||
SearchParam searchParam = buildSearchRequest(collectionName,
|
||||
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
||||
embeddingSearchRequest.filter(), embeddingSearchRequest.maxResults(), metricType,
|
||||
consistencyLevel);
|
||||
@@ -142,7 +137,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
CollectionOperationsExecutor.search(milvusClient, searchParam);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(milvusClient, resultsWrapper,
|
||||
collectionName, fieldDefinition, consistencyLevel, retrieveEmbeddingsOnSearch);
|
||||
collectionName, consistencyLevel, retrieveEmbeddingsOnSearch);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> result =
|
||||
matches.stream().filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
||||
@@ -231,7 +226,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
@Override
|
||||
public void removeAll(Filter filter) {
|
||||
ensureNotNull(filter, "filter");
|
||||
removeForVector(this.milvusClient, this.collectionName, map(filter, METADATA_FIELD_NAME));
|
||||
removeForVector(this.milvusClient, this.collectionName, map(filter));
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package dev.langchain4j.zhipu.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class ChatModelProperties {
|
||||
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
Double temperature;
|
||||
Double topP;
|
||||
String modelName;
|
||||
Integer maxRetries;
|
||||
Integer maxToken;
|
||||
Boolean logRequests;
|
||||
Boolean logResponses;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user