mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-03-22 01:03:42 +08:00
Compare commits
77 Commits
8d34dcd5dd
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1d50f978d | ||
|
|
18ce934bba | ||
|
|
6fe0ebcb9d | ||
|
|
77d8d63df7 | ||
|
|
0876f5eae8 | ||
|
|
ddbaf53ad4 | ||
|
|
4c97d01eab | ||
|
|
008f1443cb | ||
|
|
29c1119ee2 | ||
|
|
d658e437fb | ||
|
|
b6f561f18c | ||
|
|
593d26a072 | ||
|
|
9162b922c4 | ||
|
|
1d9324f689 | ||
|
|
6c5f8fce40 | ||
|
|
04b1edb2e2 | ||
|
|
9857256488 | ||
|
|
d695bed75d | ||
|
|
7490dabdc3 | ||
|
|
fad28ee5ac | ||
|
|
353c8d8b16 | ||
|
|
3dd53bad89 | ||
|
|
2d39ebf38b | ||
|
|
6c472e1c76 | ||
|
|
431aa60e4d | ||
|
|
25df22758a | ||
|
|
9af6499491 | ||
|
|
c992e57b13 | ||
|
|
80aaabe58b | ||
|
|
5a4fd2b888 | ||
|
|
5df0b87da9 | ||
|
|
ab24b1777a | ||
|
|
ff76f8edbd | ||
|
|
76745f38a4 | ||
|
|
ce4cdb62ab | ||
|
|
c2ce3a75b7 | ||
|
|
1f6d217b26 | ||
|
|
af28bc7c2a | ||
|
|
42bf355839 | ||
|
|
91e4b51ef8 | ||
|
|
bf3213e8fb | ||
|
|
c75233e37f | ||
|
|
785bda6cd9 | ||
|
|
6bd8970849 | ||
|
|
c33a85b583 | ||
|
|
62b9db6791 | ||
|
|
6d907b6adf | ||
|
|
da172a030e | ||
|
|
47c2595fb8 | ||
|
|
9bddd4457e | ||
|
|
55ac3d1aa5 | ||
|
|
0427917624 | ||
|
|
d8fe2ed2b3 | ||
|
|
11d1264d38 | ||
|
|
32675387d7 | ||
|
|
e408204690 | ||
|
|
269f146c11 | ||
|
|
6f497b142e | ||
|
|
79a44b27ee | ||
|
|
76cc5ee111 | ||
|
|
320fcf04bd | ||
|
|
75fc83010c | ||
|
|
37673c82da | ||
|
|
3ae0d645a7 | ||
|
|
256a6bcb3f | ||
|
|
1faf84e372 | ||
|
|
7e6639df83 | ||
|
|
075ae4c0af | ||
|
|
08133ccbfb | ||
|
|
164d2a9e23 | ||
|
|
f899d23b63 | ||
|
|
944beddafc | ||
|
|
019d737f07 | ||
|
|
0721df2e66 | ||
|
|
303392f492 | ||
|
|
e5a41765b4 | ||
|
|
87355533b4 |
199
CHANGELOG.md
199
CHANGELOG.md
@@ -3,6 +3,205 @@
|
||||
- All notable changes to this project will be documented in this file.
|
||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||
compatibility issues with previous versions.
|
||||
## SuperSonic [1.0.0] - 2025-08-05
|
||||
|
||||
### 重大特性变更 / Major Features
|
||||
|
||||
#### 多数据库支持扩展 / Multi-Database Support
|
||||
- **Oracle数据库支持**: 新增Oracle数据库引擎类型及适配器 ([8eeed87ba](https://github.com/tencentmusic/supersonic/commit/8eeed87ba) by supersonicbi)
|
||||
- **StarRocks支持**: 支持StarRocks和多catalog功能 ([33268bf3d](https://github.com/tencentmusic/supersonic/commit/33268bf3d) by zyclove)
|
||||
- **SAP HANA支持**: 新增SAP HANA数据库适配支持 ([2e28a4c7a](https://github.com/tencentmusic/supersonic/commit/2e28a4c7a) by wwsheng009)
|
||||
- **DuckDB支持**: 支持DuckDB数据库 ([a058dc8b6](https://github.com/tencentmusic/supersonic/commit/a058dc8b6) by jerryjzhang)
|
||||
- **Kyuubi支持**: 支持Kyuubi Presto Trino ([5e3bafb95](https://github.com/tencentmusic/supersonic/commit/5e3bafb95) by zyclove)
|
||||
- **OpenSearch支持**: 新增OpenSearch支持 ([d942d35c9](https://github.com/tencentmusic/supersonic/commit/d942d35c9) by zyclove)
|
||||
|
||||
#### 智能问答增强 / AI-Enhanced Query Processing
|
||||
- **LLM纠错器**: 新增LLM物理SQL纠错器 ([f899d23b6](https://github.com/tencentmusic/supersonic/commit/f899d23b6) by 柯慕灵)
|
||||
- **记忆管理**: Agent记忆管理启用few-shot优先机制 ([fae9118c2](https://github.com/tencentmusic/supersonic/commit/fae9118c2) by feelshana)
|
||||
- **结构化查询**: 支持struct查询中的offset子句 ([d2a43a99c](https://github.com/tencentmusic/supersonic/commit/d2a43a99c) by jerryjzhang)
|
||||
- **向量召回优化**: 优化嵌入向量召回机制 ([8c6ae6252](https://github.com/tencentmusic/supersonic/commit/8c6ae6252) by lexluo09)
|
||||
|
||||
#### 权限管理系统 / Permission Management
|
||||
- **Agent权限**: 支持agent级别的权限管理 ([b5aa6e046](https://github.com/tencentmusic/supersonic/commit/b5aa6e046) by jerryjzhang)
|
||||
- **用户管理**: 支持用户删除功能 ([1c9cf788c](https://github.com/tencentmusic/supersonic/commit/1c9cf788c) by supersonicbi)
|
||||
- **鉴权优化**: 全面优化鉴权与召回机制 ([1faf84e37](https://github.com/tencentmusic/supersonic/commit/1faf84e37), [7e6639df8](https://github.com/tencentmusic/supersonic/commit/7e6639df8) by guilinlewis)
|
||||
|
||||
### 架构升级 / Architecture Upgrades
|
||||
|
||||
#### 核心框架升级 / Core Framework Upgrades
|
||||
- **SpringBoot 3升级**: 完成SpringBoot 3.x升级 ([07f6be51c](https://github.com/tencentmusic/supersonic/commit/07f6be51c) by mislayming)
|
||||
- **依赖升级**: 升级依赖包并修复安全漏洞 ([232a20227](https://github.com/tencentmusic/supersonic/commit/232a20227) by beat4ocean)
|
||||
- **LangChain4j更新**: 替换已废弃的LangChain4j APIs ([acffc03c7](https://github.com/tencentmusic/supersonic/commit/acffc03c7) by beat4ocean)
|
||||
- **Swagger升级**: 使用SpringDoc支持Swagger在Spring 3.x ([758d170bb](https://github.com/tencentmusic/supersonic/commit/758d170bb) by jerryjzhang)
|
||||
|
||||
#### 许可证变更 / License Changes
|
||||
- **Apache 2.0**: 从MIT更改为Apache 2.0许可证 ([0aa002882](https://github.com/tencentmusic/supersonic/commit/0aa002882) by jerryjzhang)
|
||||
|
||||
### 性能优化 / Performance Improvements
|
||||
|
||||
#### 系统性能 / System Performance
|
||||
- **GC优化**: 实现Generational ZGC ([3fc1ec42b](https://github.com/tencentmusic/supersonic/commit/3fc1ec42b) by beat4ocean)
|
||||
- **Docker优化**: 减少Docker镜像体积 ([614917ba7](https://github.com/tencentmusic/supersonic/commit/614917ba7) by kino)
|
||||
- **并行处理**: 嵌入向量并行执行优化 ([8c6ae6252](https://github.com/tencentmusic/supersonic/commit/8c6ae6252) by lexluo09)
|
||||
- **记忆评估**: 记忆评估性能优化 ([524ec38ed](https://github.com/tencentmusic/supersonic/commit/524ec38ed) by yudong)
|
||||
- **多平台构建**: 支持Docker多平台构建 ([da6d28c18](https://github.com/tencentmusic/supersonic/commit/da6d28c18) by jerryjzhang)
|
||||
|
||||
#### 数据处理优化 / Data Processing Optimization
|
||||
- **日期格式**: 支持更多日期字符串格式 ([2b13866c0](https://github.com/tencentmusic/supersonic/commit/2b13866c0) by supersonicbi)
|
||||
- **SQL优化**: 优化SQL生成和执行性能 ([0ab764329](https://github.com/tencentmusic/supersonic/commit/0ab764329) by jerryjzhang)
|
||||
- **模型关联**: 优化模型关联查询性能 ([47c2595fb](https://github.com/tencentmusic/supersonic/commit/47c2595fb) by Willy-J)
|
||||
|
||||
### 功能增强 / Feature Enhancements
|
||||
|
||||
#### 前端界面优化 / Frontend Improvements
|
||||
- **图表导出**: 消息支持导出图表图片 ([ce9ae1c0c](https://github.com/tencentmusic/supersonic/commit/ce9ae1c0c) by pisces)
|
||||
- **路由重构**: 重构语义建模路由交互 ([82c63a7f2](https://github.com/tencentmusic/supersonic/commit/82c63a7f2) by tristanliu)
|
||||
- **权限界面**: 统一助理权限设置交互界面 ([46d64d78f](https://github.com/tencentmusic/supersonic/commit/46d64d78f) by tristanliu)
|
||||
- **图表优化**: 优化ChatMsg图表条件 ([06fb6ba74](https://github.com/tencentmusic/supersonic/commit/06fb6ba74) by FredTsang)
|
||||
- **数据格式**: 提取formatByDataFormatType()方法 ([9ffdba956](https://github.com/tencentmusic/supersonic/commit/9ffdba956) by FredTsang)
|
||||
|
||||
#### 开发体验 / Developer Experience
|
||||
- **构建脚本**: 优化Web应用构建脚本 ([baae7f74b](https://github.com/tencentmusic/supersonic/commit/baae7f74b) by zyclove)
|
||||
- **GitHub Actions**: 优化GitHub Actions镜像推送 ([6a4458a57](https://github.com/tencentmusic/supersonic/commit/6a4458a57) by lexluo09)
|
||||
- **基准测试**: 改进基准测试,增加解析结果分析 ([97710a90c](https://github.com/tencentmusic/supersonic/commit/97710a90c) by Antgeek)
|
||||
|
||||
### Bug修复 / Bug Fixes
|
||||
|
||||
#### 核心功能修复 / Core Function Fixes
|
||||
- **插件功能**: 修复插件功能无法调用/结果被NL2SQL覆盖问题 ([c75233e37](https://github.com/tencentmusic/supersonic/commit/c75233e37) by QJ_wonder)
|
||||
- **维度别名**: 修复映射阶段维度值别名不生效问题 ([785bda6cd](https://github.com/tencentmusic/supersonic/commit/785bda6cd) by feelshana)
|
||||
- **模型字段**: 修复模型字段更新问题 ([6bd897084](https://github.com/tencentmusic/supersonic/commit/6bd897084) by WDEP)
|
||||
- **多轮对话**: 修复headless中字段查询及多轮对话使用问题 ([be0447ae1](https://github.com/tencentmusic/supersonic/commit/be0447ae1) by QJ_wonder)
|
||||
|
||||
#### NPE异常修复 / NPE Exception Fixes
|
||||
- **聊天查询**: 修复EmbeddingMatchStrategy.detectByBatch() NPE异常 ([6d907b6ad](https://github.com/tencentmusic/supersonic/commit/6d907b6ad) by wangyong)
|
||||
- **文件处理**: 修复FileHandlerImpl.convert2Resp() 维度值数据行首字符为空格异常 ([da172a030](https://github.com/tencentmusic/supersonic/commit/da172a030) by wangyong)
|
||||
- **头部服务**: 修复多处headless NPE问题 ([79a44b27e](https://github.com/tencentmusic/supersonic/commit/79a44b27e) by jerryjzhang)
|
||||
- **解析信息**: 修复getParseInfo中的NPE ([dce9a8a58](https://github.com/tencentmusic/supersonic/commit/dce9a8a58) by supersonicbi)
|
||||
|
||||
#### SQL兼容性修复 / SQL Compatibility Fixes
|
||||
- **SQL处理**: 修复SQL前后换行符导致的语句结尾";"删除问题 ([55ac3d1aa](https://github.com/tencentmusic/supersonic/commit/55ac3d1aa) by wangyong)
|
||||
- **查询别名**: DictUtils.constructQuerySqlReq针对sql query增加别名 ([042791762](https://github.com/tencentmusic/supersonic/commit/042791762) by andybj0228)
|
||||
- **SQL变量**: 支持SQL脚本变量替换 ([0709575cd](https://github.com/tencentmusic/supersonic/commit/0709575cd) by wanglongqiang)
|
||||
|
||||
#### 前端Bug修复 / Frontend Bug Fixes
|
||||
- **UI样式**: 修复问答对话右侧历史对话模块样式异常 ([c33a85b58](https://github.com/tencentmusic/supersonic/commit/c33a85b58) by wangyong)
|
||||
- **推荐维度**: 修复页面不显示推荐下钻维度问题 ([62b9db679](https://github.com/tencentmusic/supersonic/commit/62b9db679) by WDEP)
|
||||
- **图表显示**: 修复饼图显示条件问题 ([1b8cd7f0d](https://github.com/tencentmusic/supersonic/commit/1b8cd7f0d) by WDEP)
|
||||
- **负数支持**: 支持负数显示 ([2552e2ae4](https://github.com/tencentmusic/supersonic/commit/2552e2ae4) by FredTsang)
|
||||
- **百分比显示**: 支持bar图needMultiply100显示正确百分比值 ([8abfc923a](https://github.com/tencentmusic/supersonic/commit/8abfc923a) by coosir)
|
||||
- **TypeScript错误**: 修复前端TypeScript错误 ([5585b9e22](https://github.com/tencentmusic/supersonic/commit/5585b9e22) by poncheen)
|
||||
|
||||
#### 系统兼容性修复 / System Compatibility Fixes
|
||||
- **Windows脚本**: 修复Windows daemon.bat路径配置问题 ([e5a41765b](https://github.com/tencentmusic/supersonic/commit/e5a41765b) by 柯慕灵)
|
||||
- **字符编码**: 将utf8编码修改为utf8mb4,解决字符问题 ([2e81b190a](https://github.com/tencentmusic/supersonic/commit/2e81b190a) by Kun Gu)
|
||||
- **记忆缓存**: 修复记忆管理中因缓存无法存储的问题 ([81cd60d2d](https://github.com/tencentmusic/supersonic/commit/81cd60d2d) by guilinlewis)
|
||||
- **Mac兼容**: 降级djl库以支持Mac Intel机器 ([bf3213e8f](https://github.com/tencentmusic/supersonic/commit/bf3213e8f) by jerryjzhang)
|
||||
|
||||
### 数据管理优化 / Data Management Improvements
|
||||
|
||||
#### 维度指标管理 / Dimension & Metric Management
|
||||
- **维度检索**: 修复维度和指标检索及百分比显示问题 ([d8fe2ed2b](https://github.com/tencentmusic/supersonic/commit/d8fe2ed2b) by 木鱼和尚)
|
||||
- **查询导出**: 基于queryColumns导出数据 ([11d1264d3](https://github.com/tencentmusic/supersonic/commit/11d1264d3) by FredTsang)
|
||||
- **表格排序**: 移除表格defaultSortOrder ([32675387d](https://github.com/tencentmusic/supersonic/commit/32675387d) by FredTsang)
|
||||
- **维度搜索**: 修复维度搜索带key查询范围超出问题 ([269f146c1](https://github.com/tencentmusic/supersonic/commit/269f146c1) by wangyong)
|
||||
|
||||
### 测试和质量保证 / Testing & Quality Assurance
|
||||
|
||||
#### 单元测试 / Unit Testing
|
||||
- **测试修复**: 修复单元测试用例 ([91e4b51ef](https://github.com/tencentmusic/supersonic/commit/91e4b51ef) by jerryjzhang)
|
||||
- **模型测试**: 修复ModelCreateForm.tsx错误 ([d2aa73b85](https://github.com/tencentmusic/supersonic/commit/d2aa73b85) by Antgeek)
|
||||
|
||||
### 重要变更说明 / Breaking Changes
|
||||
|
||||
#### 升级注意事项 / Upgrade Notes
|
||||
1. **SpringBoot 3升级**: 可能需要更新依赖配置和代码适配
|
||||
2. **许可证变更**: 从MIT变更为Apache 2.0,请注意法律合规
|
||||
3. **API接口调整**: 部分API接口为支持新功能进行了调整
|
||||
4. **数据库兼容**: 新增多种数据库支持,配置方式有所变化
|
||||
|
||||
### 完整提交统计 / Commit Statistics
|
||||
- **总提交数**: 419个提交
|
||||
- **主要贡献者**:
|
||||
- jerryjzhang: 158次提交
|
||||
- supersonicbi: 22次提交
|
||||
- zyclove: 20次提交
|
||||
- beat4ocean: 15次提交
|
||||
- guilinlewis: 11次提交
|
||||
- wangyong: 11次提交
|
||||
- 其他贡献者: 182次提交
|
||||
- **涉及模块**: headless, chat, auth, common, webapp, launcher, docker
|
||||
- **时间跨度**: 2024年11月1日 - 2025年8月5日
|
||||
|
||||
### 致谢 / Acknowledgments
|
||||
|
||||
感谢所有为SuperSonic 1.0.0版本贡献代码、文档、测试和建议的开发者们!🎉
|
||||
|
||||
#### 核心贡献者 / Core Contributors
|
||||
- **jerryjzhang** - 项目维护者,核心架构设计与实现
|
||||
- **supersonicbi** - 核心功能开发,多数据库支持
|
||||
- **beat4ocean** - 架构升级,依赖管理,安全优化
|
||||
- **zyclove** - 数据库适配,构建优化
|
||||
- **guilinlewis** - 鉴权系统,召回优化
|
||||
- **wangyong** - Bug修复,NPE异常处理
|
||||
|
||||
#### 活跃贡献者 / Active Contributors
|
||||
- **WDEP** - 前端优化,图表功能
|
||||
- **FredTsang** - Chat SDK优化,数据导出
|
||||
- **feelshana** - 记忆管理,向量召回
|
||||
- **QJ_wonder** - 插件功能,多轮对话
|
||||
- **Willy-J** - 模型关联,数据库兼容
|
||||
- **iridescentpeo** - 查询优化,模型管理
|
||||
- **tristanliu** - 前端路由,权限界面
|
||||
- **mislayming** - SpringBoot 3升级
|
||||
- **Antgeek** - 基准测试,模型修复
|
||||
- **柯慕灵** - LLM纠错器,Windows脚本
|
||||
- **superhero** - 项目管理,代码审查
|
||||
|
||||
#### 其他重要贡献者 / Other Important Contributors
|
||||
- **木鱼和尚** - 维度指标检索优化
|
||||
- **pisces** - 图表导出功能
|
||||
- **lexluo09** - 并行处理,GitHub Actions
|
||||
- **andybj0228** - SQL查询优化
|
||||
- **wanglongqiang** - SQL变量支持
|
||||
- **Hyman_bz** - StarRocks支持
|
||||
- **wwsheng009** - SAP HANA适配
|
||||
- **poncheen** - TypeScript错误修复
|
||||
- **kino** - Docker镜像优化
|
||||
- **coosir** - 前端百分比显示
|
||||
- **Kun Gu** - 字符编码优化
|
||||
- **chixiaopao** - NPE异常修复
|
||||
- **naimehao** - 核心功能修复
|
||||
- **yudong** - 记忆评估优化
|
||||
- **mroldx** - 数据库脚本更新
|
||||
- **ChPi** - 解析器性能优化
|
||||
- **Hwting** - Docker配置优化
|
||||
|
||||
#### 特别感谢 / Special Thanks
|
||||
感谢所有提交Issue、参与讨论、提供反馈的社区用户,你们的每一个建议都让SuperSonic变得更好!
|
||||
|
||||
#### 社区支持 / Community Support
|
||||
SuperSonic是一个开源项目,我们欢迎更多开发者加入:
|
||||
- 🔗 **GitHub**: https://github.com/tencentmusic/supersonic
|
||||
- 📖 **文档**: 详见项目README和Wiki
|
||||
- 🐛 **Issue报告**: 欢迎提交Bug和功能请求
|
||||
- 🚀 **贡献代码**: 欢迎提交Pull Request
|
||||
- 💬 **社区讨论**: 加入我们的技术交流群
|
||||
|
||||
#### 未来展望 / Future Vision
|
||||
SuperSonic 1.0.0是一个重要的里程碑,但这只是开始。我们将继续:
|
||||
- 🌟 **持续优化性能和稳定性**
|
||||
- 🔧 **扩展更多数据库和AI模型支持**
|
||||
- 🎨 **改善用户体验和界面设计**
|
||||
- 📚 **完善文档和最佳实践**
|
||||
- 🤝 **建设更活跃的开源社区**
|
||||
|
||||
**让我们一起把SuperSonic做得更好!** ✨
|
||||
|
||||
---
|
||||
|
||||
*如果您在使用过程中遇到问题或有改进建议,欢迎随时与我们交流。每一份贡献都让SuperSonic更加强大!*
|
||||
|
||||
|
||||
## SuperSonic [0.9.8] - 2024-11-01
|
||||
- Add LLM management module to reuse connection across agents.
|
||||
|
||||
113
CLAUDE.md
Normal file
113
CLAUDE.md
Normal file
@@ -0,0 +1,113 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Build Commands
|
||||
|
||||
### Backend (Java/Maven)
|
||||
|
||||
```bash
|
||||
# Clean build (skip tests)
|
||||
mvn clean package -DskipTests -Dspotless.skip=true
|
||||
|
||||
# Run all tests
|
||||
mvn test
|
||||
|
||||
# Run single test class
|
||||
mvn test -Dtest=ClassName
|
||||
|
||||
# Full CI build
|
||||
mvn -B package --file pom.xml
|
||||
```
|
||||
|
||||
**Requirements:** Java 21, Maven
|
||||
|
||||
### Frontend (pnpm/React)
|
||||
|
||||
```bash
|
||||
cd webapp
|
||||
|
||||
# Install dependencies
|
||||
pnpm install
|
||||
|
||||
# Start dev server (port 9000)
|
||||
pnpm dev
|
||||
|
||||
# Production build
|
||||
pnpm build
|
||||
|
||||
# Run tests
|
||||
pnpm test
|
||||
```
|
||||
|
||||
**Requirements:** Node.js >=16, pnpm 9.12.3+
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Build full release
|
||||
./assembly/bin/supersonic-build.sh standalone
|
||||
|
||||
# Start service
|
||||
./assembly/bin/supersonic-daemon.sh start
|
||||
|
||||
# Stop service
|
||||
./assembly/bin/supersonic-daemon.sh stop
|
||||
```
|
||||
|
||||
Visit http://localhost:9080 after startup.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
SuperSonic unifies **Chat BI** (LLM-powered) and **Headless BI** (semantic layer) paradigms.
|
||||
|
||||
### Core Modules
|
||||
|
||||
```
|
||||
supersonic/
|
||||
├── auth/ # Authentication & authorization (SPI-based)
|
||||
├── chat/ # Chat BI module - LLM-powered Q&A interface
|
||||
├── common/ # Shared utilities
|
||||
├── headless/ # Headless BI - semantic layer with open API
|
||||
├── launchers/ # Application entry points
|
||||
│ ├── standalone/ # Combined Chat + Headless (default)
|
||||
│ ├── chat/ # Chat-only service
|
||||
│ └── headless/ # Headless-only service
|
||||
└── webapp/ # Frontend React app (UmiJS 4 + Ant Design)
|
||||
```
|
||||
|
||||
### Data Flow
|
||||
|
||||
1. **Knowledge Base**: Extracts schema from semantic models, builds dictionary/index for schema mapping
|
||||
2. **Schema Mapper**: Identifies metrics/dimensions/entities/values in user queries
|
||||
3. **Semantic Parser**: Generates S2SQL (semantic SQL) using rule-based and LLM-based parsers
|
||||
4. **Semantic Corrector**: Validates and corrects semantic queries
|
||||
5. **Semantic Translator**: Converts S2SQL to executable SQL
|
||||
|
||||
### Key Entry Points
|
||||
|
||||
- `StandaloneLauncher.java` - Combined service with `scanBasePackages: ["com.tencent.supersonic", "dev.langchain4j"]`
|
||||
- `ChatLauncher.java` - Chat BI only
|
||||
- `HeadlessLauncher.java` - Headless BI only
|
||||
|
||||
## Key Technologies
|
||||
|
||||
**Backend:** Spring Boot 3.3.9, MyBatis-Plus 3.5.10.1, LangChain4j 0.36.2, JSqlParser 4.9, Calcite 1.38.0
|
||||
|
||||
**Frontend:** React 18, UmiJS 4, Ant Design 5.17.4, ECharts 5.0.2, AntV G6/X6
|
||||
|
||||
**Databases:** MySQL, PostgreSQL (with pgvector), H2, ClickHouse, StarRocks, Presto, Trino, DuckDB
|
||||
|
||||
## Testing
|
||||
|
||||
**Java tests:** JUnit 5, Mockito. Located in `src/test/java/` of each module.
|
||||
|
||||
**Frontend tests:** Jest with Puppeteer environment in `webapp/packages/supersonic-fe/`
|
||||
|
||||
**Evaluation scripts:** Python scripts in `evaluation/` directory for Text2SQL accuracy testing.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [README.md](README.md) - English documentation
|
||||
- [README_CN.md](README_CN.md) - Chinese documentation
|
||||
- [Evaluation Guide](evaluation/README.md) - Text2SQL evaluation process
|
||||
@@ -43,10 +43,26 @@ if "%service%"=="webapp" (
|
||||
call mvn -f %projectDir% clean package -DskipTests -Dspotless.skip=true
|
||||
IF ERRORLEVEL 1 (
|
||||
ECHO Failed to build backend Java modules.
|
||||
ECHO Please check Maven and Java versions are compatible.
|
||||
ECHO Current Java: %JAVA_HOME%
|
||||
ECHO Current Maven: %MAVEN_HOME%
|
||||
EXIT /B 1
|
||||
)
|
||||
|
||||
REM extract and copy files to deployment directory
|
||||
cd %projectDir%\launchers\%model_name%\target
|
||||
if exist "launchers-%model_name%-%MVN_VERSION%-bin.tar.gz" (
|
||||
echo "Extracting launchers-%model_name%-%MVN_VERSION%-bin.tar.gz..."
|
||||
tar -xf "launchers-%model_name%-%MVN_VERSION%-bin.tar.gz"
|
||||
if exist "launchers-%model_name%-%MVN_VERSION%" (
|
||||
echo "Copying files to deployment directory..."
|
||||
xcopy /E /Y "launchers-%model_name%-%MVN_VERSION%\*" "%buildDir%\supersonic-%model_name%-%MVN_VERSION%\"
|
||||
)
|
||||
)
|
||||
|
||||
copy /y %projectDir%\launchers\%model_name%\target\*.tar.gz %buildDir%\
|
||||
echo "finished building supersonic-%model_name% service"
|
||||
cd %baseDir%
|
||||
goto :EOF
|
||||
|
||||
|
||||
@@ -72,22 +88,55 @@ if "%service%"=="webapp" (
|
||||
cd %buildDir%
|
||||
if exist %release_dir% rmdir /s /q %release_dir%
|
||||
if exist %release_dir%.zip del %release_dir%.zip
|
||||
mkdir %release_dir%
|
||||
rem package webapp
|
||||
tar xvf supersonic-webapp.tar.gz
|
||||
move /y supersonic-webapp webapp
|
||||
echo {"env": ""} > webapp\supersonic.config.json
|
||||
move /y webapp %release_dir%
|
||||
rem package java service
|
||||
tar xvf %service_name%-bin.tar.gz
|
||||
for /d %%D in ("%service_name%\*") do (
|
||||
move "%%D" "%release_dir%"
|
||||
|
||||
rem check if release directory already exists from buildJavaService
|
||||
if exist %release_dir% (
|
||||
echo "Release directory already prepared by buildJavaService"
|
||||
) else (
|
||||
mkdir %release_dir%
|
||||
|
||||
rem package java service
|
||||
tar xvf %service_name%-bin.tar.gz 2>nul
|
||||
if errorlevel 1 (
|
||||
echo "Warning: tar command failed, trying PowerShell extraction..."
|
||||
powershell -Command "Expand-Archive -Path '%service_name%-bin.tar.gz' -DestinationPath '.' -Force"
|
||||
)
|
||||
for /d %%D in ("%service_name%\*") do (
|
||||
move "%%D" "%release_dir%"
|
||||
)
|
||||
rmdir /s /q %service_name% 2>nul
|
||||
)
|
||||
|
||||
rem package webapp
|
||||
if exist supersonic-webapp.tar.gz (
|
||||
tar xvf supersonic-webapp.tar.gz 2>nul
|
||||
if errorlevel 1 (
|
||||
echo "Warning: tar command failed, trying PowerShell extraction..."
|
||||
powershell -Command "Expand-Archive -Path 'supersonic-webapp.tar.gz' -DestinationPath '.' -Force"
|
||||
)
|
||||
move /y supersonic-webapp webapp
|
||||
echo {"env": ""} > webapp\supersonic.config.json
|
||||
move /y webapp %release_dir%
|
||||
del supersonic-webapp.tar.gz 2>nul
|
||||
)
|
||||
|
||||
rem verify deployment structure
|
||||
if exist "%release_dir%\lib\launchers-%model_name%-%MVN_VERSION%.jar" (
|
||||
echo "Deployment structure verified successfully"
|
||||
) else (
|
||||
echo "Warning: Main jar file not found in deployment structure"
|
||||
echo "Expected: %release_dir%\lib\launchers-%model_name%-%MVN_VERSION%.jar"
|
||||
)
|
||||
|
||||
rem generate zip file
|
||||
powershell Compress-Archive -Path %release_dir% -DestinationPath %release_dir%.zip
|
||||
del %service_name%-bin.tar.gz
|
||||
del supersonic-webapp.tar.gz
|
||||
rmdir /s /q %service_name%
|
||||
powershell -Command "Compress-Archive -Path '%release_dir%' -DestinationPath '%release_dir%.zip' -Force"
|
||||
if errorlevel 1 (
|
||||
echo "Warning: PowerShell compression failed, release directory still available: %release_dir%"
|
||||
) else (
|
||||
echo "Successfully created release package: %release_dir%.zip"
|
||||
)
|
||||
|
||||
del %service_name%-bin.tar.gz 2>nul
|
||||
echo "finished packaging supersonic release"
|
||||
goto :EOF
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ if "%profile%"=="" (
|
||||
|
||||
set "model_name=%service%"
|
||||
|
||||
cd %baseDir%
|
||||
REM fix path configuration - point to the correct release package directory
|
||||
set "releaseDir=%buildDir%\supersonic-%service%-1.0.0-SNAPSHOT"
|
||||
cd %releaseDir%
|
||||
|
||||
if "%command%"=="restart" (
|
||||
call :stop
|
||||
@@ -50,20 +52,58 @@ if "%command%"=="restart" (
|
||||
|
||||
:runJavaService
|
||||
echo 'java service starting, see logs in logs/'
|
||||
set "libDir=%baseDir%\lib"
|
||||
set "confDir=%baseDir%\conf"
|
||||
set "webDir=%baseDir%\webapp"
|
||||
set "logDir=%baseDir%\logs"
|
||||
set "classpath=%baseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||
set "property=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile%"
|
||||
set "java-command=%property% -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||
echo 'Using release directory: %releaseDir%'
|
||||
|
||||
REM use release package directory as base path
|
||||
set "libDir=%releaseDir%\lib"
|
||||
set "confDir=%releaseDir%\conf"
|
||||
set "webDir=%releaseDir%\webapp"
|
||||
set "logDir=%releaseDir%\logs"
|
||||
|
||||
REM fix variable name matching problem
|
||||
set "CLASSPATH=%releaseDir%;%webDir%;%libDir%\*;%confDir%"
|
||||
set "MAIN_CLASS=%main_class%"
|
||||
|
||||
REM add port configuration
|
||||
set "property=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Dspring.profiles.active=%profile% -Dserver.port=9080"
|
||||
set "java_command=%property% -Xms1024m -Xmx2048m -cp "%CLASSPATH%" %MAIN_CLASS%"
|
||||
|
||||
if not exist %logDir% mkdir %logDir%
|
||||
start /B java %java-command% >nul 2>&1
|
||||
timeout /t 10 >nul
|
||||
|
||||
REM check if the main jar file exists
|
||||
if not exist "%libDir%\launchers-standalone-1.0.0-SNAPSHOT.jar" (
|
||||
echo "Error: Main jar file not found in %libDir%"
|
||||
echo "Please make sure the application has been built and packaged correctly."
|
||||
goto :EOF
|
||||
)
|
||||
|
||||
echo 'Main Class: %MAIN_CLASS%'
|
||||
echo 'Profile: %profile%'
|
||||
echo 'Starting Java service...'
|
||||
|
||||
REM start service and save logs
|
||||
start /B java %java_command% > "%logDir%\supersonic.log" 2>&1
|
||||
timeout /t 15 >nul
|
||||
|
||||
REM check service status
|
||||
netstat -an | findstr ":9080" >nul
|
||||
if errorlevel 1 (
|
||||
echo "Warning: Port 9080 is not listening"
|
||||
echo "Please check the log file: %logDir%\supersonic.log"
|
||||
if exist "%logDir%\supersonic.log" (
|
||||
echo "Recent log entries:"
|
||||
powershell -Command "Get-Content '%logDir%\supersonic.log' | Select-Object -Last 10"
|
||||
)
|
||||
) else (
|
||||
echo "Service started successfully on port 9080"
|
||||
echo "You can access the application at: http://localhost:9080"
|
||||
)
|
||||
|
||||
echo 'java service started'
|
||||
goto :EOF
|
||||
|
||||
:stopJavaService
|
||||
echo 'Stopping Java service...'
|
||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
||||
taskkill /PID %%i /F
|
||||
echo "java service (PID = %%i) is killed."
|
||||
|
||||
@@ -15,4 +15,6 @@ public class UserReq {
|
||||
|
||||
@NotBlank(message = "password can not be null")
|
||||
private String newPassword;
|
||||
|
||||
private String role;
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.sql.Timestamp;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
@@ -222,8 +223,9 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
||||
new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
||||
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
||||
|
||||
String token =
|
||||
tokenService.generateToken(UserWithPassword.convert(userWithPassword), expireTime);
|
||||
// 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况
|
||||
String token = tokenService.generateToken(UserWithPassword.convert(userWithPassword),
|
||||
"SysDbToken:" + name, (new Date().getTime() + expireTime));
|
||||
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
|
||||
return convertUserToken(userTokenDO);
|
||||
}
|
||||
|
||||
@@ -21,6 +21,8 @@ public interface UserRepository {
|
||||
|
||||
UserTokenDO getUserToken(Long tokenId);
|
||||
|
||||
UserTokenDO getUserTokenByName(String tokenName);
|
||||
|
||||
void deleteUserTokenByName(String userName);
|
||||
|
||||
void deleteUserToken(Long tokenId);
|
||||
|
||||
@@ -65,6 +65,13 @@ public class UserRepositoryImpl implements UserRepository {
|
||||
return userTokenDOMapper.selectById(tokenId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UserTokenDO getUserTokenByName(String tokenName) {
|
||||
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||
queryWrapper.lambda().eq(UserTokenDO::getName, tokenName);
|
||||
return userTokenDOMapper.selectOne(queryWrapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteUserTokenByName(String userName) {
|
||||
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||
|
||||
@@ -6,7 +6,10 @@ import javax.crypto.spec.SecretKeySpec;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserTokenDO;
|
||||
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
|
||||
import com.tencent.supersonic.common.pojo.exception.AccessException;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import io.jsonwebtoken.Claims;
|
||||
import io.jsonwebtoken.Jwts;
|
||||
import io.jsonwebtoken.SignatureAlgorithm;
|
||||
@@ -71,6 +74,7 @@ public class TokenService {
|
||||
return generateToken(UserWithPassword.convert(appUser), request);
|
||||
}
|
||||
|
||||
|
||||
public Optional<Claims> getClaims(HttpServletRequest request) {
|
||||
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
||||
String appKey = getAppKey(request);
|
||||
@@ -90,6 +94,14 @@ public class TokenService {
|
||||
|
||||
public Optional<Claims> getClaims(String token, String appKey) {
|
||||
try {
|
||||
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
|
||||
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||
UserTokenDO dbToken =
|
||||
userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
|
||||
if (dbToken == null || !dbToken.getToken().equals(token.replace("Bearer ", ""))) {
|
||||
throw new AccessException("Token does not exist :" + appKey);
|
||||
}
|
||||
}
|
||||
String tokenSecret = getTokenSecret(appKey);
|
||||
Claims claims =
|
||||
Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
|
||||
@@ -122,6 +134,16 @@ public class TokenService {
|
||||
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
|
||||
String secret = appKeyToSecretMap.get(appKey);
|
||||
if (StringUtils.isBlank(secret)) {
|
||||
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
|
||||
String realAppKey = appKey.substring("SysDbToken:".length());
|
||||
String tmp =
|
||||
"WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
|
||||
if (tmp.length() <= realAppKey.length()) {
|
||||
return realAppKey;
|
||||
} else {
|
||||
return realAppKey + tmp.substring(realAppKey.length());
|
||||
}
|
||||
}
|
||||
throw new AccessException("get secret from appKey failed :" + appKey);
|
||||
}
|
||||
return secret;
|
||||
|
||||
@@ -18,4 +18,5 @@ public class ChatExecuteReq {
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private boolean saveAnswer;
|
||||
private boolean streamingResult;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ChatMemoryDeleteReq {
|
||||
|
||||
private List<Long> ids;
|
||||
|
||||
private Integer agentId;
|
||||
}
|
||||
@@ -75,8 +75,12 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
return null;
|
||||
}
|
||||
|
||||
QuerySqlReq sqlReq =
|
||||
QuerySqlReq.builder().sql(parseInfo.getSqlInfo().getCorrectedS2SQL()).build();
|
||||
// 使用querySQL,它已经包含了所有修正(包括物理SQL修正)
|
||||
String finalSql = StringUtils.isNotBlank(parseInfo.getSqlInfo().getQuerySQL())
|
||||
? parseInfo.getSqlInfo().getQuerySQL()
|
||||
: parseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
QuerySqlReq sqlReq = QuerySqlReq.builder().sql(finalSql).build();
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
|
||||
@@ -90,7 +94,7 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQuerySql(finalSql);
|
||||
queryResult.setQueryResults(queryResp.getResultList());
|
||||
queryResult.setQueryColumns(queryResp.getColumns());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
@@ -32,6 +32,7 @@ import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -171,10 +172,6 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
return;
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
||||
|
||||
List<QueryResp> historyQueries =
|
||||
getHistoryQueries(parseContext.getRequest().getChatId(), 1);
|
||||
if (historyQueries.isEmpty()) {
|
||||
@@ -182,12 +179,18 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
QueryResp lastQuery = historyQueries.get(0);
|
||||
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(histSQL)) // 优化性能,如果问答不是chat bi 则无需重写,因为数据都不全
|
||||
return;
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq); // 优化性能 ,只有满足条件才mapping
|
||||
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
String curtMapStr =
|
||||
generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("current_question", currentMapResult.getQueryText());
|
||||
|
||||
@@ -35,9 +35,7 @@ public class ChatMemoryRepositoryImpl implements ChatMemoryRepository {
|
||||
if (CollectionUtils.isEmpty(ids)) {
|
||||
return;
|
||||
}
|
||||
for (Long id : ids) {
|
||||
chatMemoryMapper.deleteById(id);
|
||||
}
|
||||
chatMemoryMapper.deleteByIds(ids);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -19,7 +19,8 @@ public class ParseContext {
|
||||
}
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
return Objects.nonNull(agent) && agent.containsDatasetTool()&&response.getSelectedParses().size() == 0;
|
||||
return Objects.nonNull(agent) && agent.containsDatasetTool()
|
||||
&& response.getSelectedParses().size() == 0;
|
||||
}
|
||||
|
||||
public boolean enableLLM() {
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
@@ -24,9 +31,11 @@ import java.util.Objects;
|
||||
* DataInterpretProcessor interprets query result to make it more readable to the users.
|
||||
*/
|
||||
public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
|
||||
public static String tip = "AI 回答中...\r\n";
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static Map<Long, StringBuffer> resultCache = new HashMap<>();
|
||||
|
||||
public static final String APP_KEY = "DATA_INTERPRETER";
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a data expert who communicates with business users everyday."
|
||||
@@ -41,13 +50,24 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
.appModule(AppModule.CHAT).description("通过大模型对结果数据做提炼总结").enable(false).build());
|
||||
}
|
||||
|
||||
public static String getTextSummary(Long queryId) {
|
||||
if (resultCache.get(queryId) != null) {
|
||||
return resultCache.get(queryId).toString();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
public static Map<Long, StringBuffer> getResultCache() {
|
||||
return resultCache;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
Agent agent = executeContext.getAgent();
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
return Objects.nonNull(chatApp) && chatApp.isEnable()
|
||||
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()); // 如果都没结果,则无法处理,直接跳过
|
||||
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
|
||||
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -57,18 +77,62 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("question", executeContext.getRequest().getQueryText());
|
||||
String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
|
||||
if (executeContext.getParseInfo().getProperties() != null
|
||||
&& executeContext.getParseInfo().getProperties().containsKey("CONTEXT")) {
|
||||
Map<String, Object> context = (Map<String, Object>) executeContext.getParseInfo()
|
||||
.getProperties().get("CONTEXT");
|
||||
if (context.get("queryText") != null && "".equals(context.get("queryText"))) {
|
||||
question = context.get("queryText").toString();
|
||||
}
|
||||
}
|
||||
variable.put("question", question);
|
||||
variable.put("data", queryResult.getTextResult());
|
||||
|
||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String anwser = response.content().text();
|
||||
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
|
||||
anwser);
|
||||
if (StringUtils.isNotBlank(anwser)) {
|
||||
queryResult.setTextSummary(anwser);
|
||||
if (executeContext.getRequest().isStreamingResult()) {
|
||||
StreamingChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatStreamingModel(chatApp.getChatModelConfig());
|
||||
final Long queryId = executeContext.getRequest().getQueryId();
|
||||
resultCache.put(queryId, new StringBuffer(tip));
|
||||
chatLanguageModel.generate(prompt.toUserMessage(),
|
||||
new StreamingResponseHandler<AiMessage>() {
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
resultCache.get(queryId).append(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
ChatQueryRepository chatQueryRepository =
|
||||
ContextUtils.getBean(ChatQueryRepository.class);
|
||||
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(queryId);
|
||||
JSONObject queryResult = JSON.parseObject(chatQueryDO.getQueryResult());
|
||||
queryResult.put("textSummary",
|
||||
resultCache.get(queryId).toString().substring(tip.length()));
|
||||
chatQueryDO.setQueryResult(queryResult.toJSONString());
|
||||
chatQueryRepository.updateChatQuery(chatQueryDO);
|
||||
resultCache.remove(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
error.printStackTrace();
|
||||
resultCache.remove(queryId);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String anwser = response.content().text();
|
||||
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}",
|
||||
prompt.text(), anwser);
|
||||
if (StringUtils.isNotBlank(anwser)) {
|
||||
queryResult.setTextSummary(anwser);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
@@ -50,6 +51,14 @@ public class ChatQueryController {
|
||||
return chatQueryService.execute(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("getExecuteSummary")
|
||||
public Object getExecuteSummary(@RequestBody ChatExecuteReq chatExecuteReq,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
QueryResult res = chatQueryService.getTextSummary(chatExecuteReq);
|
||||
return res;
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
public Object query(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
|
||||
@@ -4,12 +4,12 @@ import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryCreateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryDeleteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatMemory;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -53,8 +53,10 @@ public class MemoryController {
|
||||
}
|
||||
|
||||
@PostMapping("batchDelete")
|
||||
public Boolean batchDelete(@RequestBody MetaBatchReq metaBatchReq) {
|
||||
memoryService.batchDelete(metaBatchReq.getIds());
|
||||
public Boolean deleteMemory(@RequestBody ChatMemoryDeleteReq chatMemoryDeleteReq,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
memoryService.batchDelete(chatMemoryDeleteReq, user);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,6 +35,8 @@ public interface ChatManageService {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
ChatQueryDO getChatQueryDO(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
||||
|
||||
@@ -19,6 +19,8 @@ public interface ChatQueryService {
|
||||
|
||||
QueryResult execute(ChatExecuteReq chatExecuteReq) throws Exception;
|
||||
|
||||
QueryResult getTextSummary(ChatExecuteReq chatExecuteReq);
|
||||
|
||||
QueryResult parseAndExecute(ChatParseReq chatParseReq);
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryDeleteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
@@ -14,7 +15,7 @@ public interface MemoryService {
|
||||
|
||||
void updateMemory(ChatMemoryUpdateReq chatMemoryUpdateReq, User user);
|
||||
|
||||
void batchDelete(List<Long> ids);
|
||||
void batchDelete(ChatMemoryDeleteReq chatMemoryDeleteReq, User user);
|
||||
|
||||
PageInfo<ChatMemory> pageMemories(PageMemoryReq pageMemoryReq);
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -39,6 +40,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
@Lazy
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -123,6 +123,11 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
return chatQueryRepository.getChatQuery(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatQueryDO getChatQueryDO(Long queryId) {
|
||||
return chatQueryRepository.getChatQueryDO(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
@@ -9,8 +10,10 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.DataInterpretProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
@@ -18,7 +21,11 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.jsqlparser.*;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
@@ -44,15 +51,27 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.*;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.InExpression;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ParenthesedExpressionList;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@@ -66,6 +85,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
@Lazy
|
||||
private AgentService agentService;
|
||||
|
||||
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
@@ -108,6 +128,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
if (!parseContext.needFeedback()) {
|
||||
parseContext.getResponse().getParseTimeCost().setParseTime(System.currentTimeMillis()
|
||||
- parseContext.getResponse().getParseTimeCost().getParseStartTime());
|
||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||
}
|
||||
@@ -141,6 +163,21 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult getTextSummary(ChatExecuteReq chatExecuteReq) {
|
||||
String text = DataInterpretProcessor.getTextSummary(chatExecuteReq.getQueryId());
|
||||
if (StringUtils.isNotBlank(text)) {
|
||||
QueryResult res = new QueryResult();
|
||||
res.setTextSummary(text);
|
||||
res.setQueryId(chatExecuteReq.getQueryId());
|
||||
return res;
|
||||
} else {
|
||||
ChatQueryDO chatQueryDo = chatManageService.getChatQueryDO(chatExecuteReq.getQueryId());
|
||||
QueryResult res = JSON.parseObject(chatQueryDo.getQueryResult(), QueryResult.class);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
|
||||
ChatParseResp parseResp = parse(chatParseReq);
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryDeleteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryUpdateReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageMemoryReq;
|
||||
@@ -26,7 +27,7 @@ import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -71,8 +72,9 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
chatMemoryDO.setS2sql(chatMemoryUpdateReq.getS2sql());
|
||||
chatMemoryDO.setDbSchema(chatMemoryUpdateReq.getDbSchema());
|
||||
enableMemory(chatMemoryDO);
|
||||
} else if ((MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus())||MemoryStatus.PENDING.equals(chatMemoryUpdateReq.getStatus())) && hadEnabled) {
|
||||
// Remove from vector DB when transitioning: launched→disabled OR enabled→pending
|
||||
} else if ((MemoryStatus.DISABLED.equals(chatMemoryUpdateReq.getStatus())
|
||||
|| MemoryStatus.PENDING.equals(chatMemoryUpdateReq.getStatus())) && hadEnabled) {
|
||||
// Remove from vector DB when transitioning: launched→disabled OR enabled→pending
|
||||
disableMemory(chatMemoryDO);
|
||||
}
|
||||
LambdaUpdateWrapper<ChatMemoryDO> updateWrapper = new LambdaUpdateWrapper<>();
|
||||
@@ -108,7 +110,22 @@ public class MemoryServiceImpl implements MemoryService, CommandLineRunner {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void batchDelete(List<Long> ids) {
|
||||
public void batchDelete(ChatMemoryDeleteReq chatMemoryDeleteReq, User user) {
|
||||
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
|
||||
if (!CollectionUtils.isEmpty(chatMemoryDeleteReq.getIds())) {
|
||||
queryWrapper.lambda().in(ChatMemoryDO::getId, chatMemoryDeleteReq.getIds());
|
||||
}
|
||||
if (chatMemoryDeleteReq.getAgentId() != null) {
|
||||
queryWrapper.lambda().eq(ChatMemoryDO::getAgentId, chatMemoryDeleteReq.getAgentId());
|
||||
}
|
||||
List<ChatMemoryDO> chatMemoryDOS = chatMemoryRepository.getMemories(queryWrapper);
|
||||
List<Long> ids = new ArrayList<>();
|
||||
chatMemoryDOS.forEach(chatMemoryDO -> {
|
||||
if (MemoryStatus.ENABLED.toString().equals(chatMemoryDO.getStatus().trim())) {
|
||||
disableMemory(chatMemoryDO);
|
||||
}
|
||||
ids.add(chatMemoryDO.getId());
|
||||
});
|
||||
chatMemoryRepository.batchDelete(ids);
|
||||
}
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ public class PluginServiceImpl implements PluginService {
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getCreatedBy())) {
|
||||
queryWrapper.lambda().eq(PluginDO::getCreatedBy, pluginQueryReq.getCreatedBy());
|
||||
}
|
||||
queryWrapper.orderByAsc("name");
|
||||
List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper);
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
|
||||
pluginDOS = pluginDOS.stream()
|
||||
|
||||
@@ -21,7 +21,8 @@ public class LoadRemoveService {
|
||||
List<String> resultList = new ArrayList<>(value);
|
||||
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
||||
resultList.removeIf(nature -> {
|
||||
if (Objects.isNull(nature)) {
|
||||
if (Objects.isNull(nature) || !nature.startsWith("_")) { // 系统的字典是以 _ 开头的,
|
||||
// 过滤因引用外部字典导致的异常
|
||||
return false;
|
||||
}
|
||||
Long id = getId(nature);
|
||||
|
||||
@@ -46,6 +46,62 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
replaceComparisonExpression(expr);
|
||||
}
|
||||
|
||||
public void visit(LikeExpression expr) {
|
||||
Expression leftExpression = expr.getLeftExpression();
|
||||
Expression rightExpression = expr.getRightExpression();
|
||||
|
||||
if (!(leftExpression instanceof Column)) {
|
||||
return;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(filedNameToValueMap)) {
|
||||
return;
|
||||
}
|
||||
if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|
||||
return;
|
||||
}
|
||||
Column column = (Column) leftExpression;
|
||||
String columnName = column.getColumnName();
|
||||
if (StringUtils.isEmpty(columnName)) {
|
||||
return;
|
||||
}
|
||||
Map<String, String> valueMap = filedNameToValueMap.get(columnName);
|
||||
if (Objects.isNull(valueMap) || valueMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
if (rightExpression instanceof StringValue) {
|
||||
StringValue rightStringValue = (StringValue) rightExpression;
|
||||
String value = rightStringValue.getValue();
|
||||
|
||||
// 使用split处理方式,按通配符分割字符串,对每个片段进行转换
|
||||
String[] parts = value.split("%", -1);
|
||||
boolean changed = false;
|
||||
|
||||
// 处理每个部分
|
||||
for (int i = 0; i < parts.length; i++) {
|
||||
if (!parts[i].isEmpty()) {
|
||||
String replaceValue = getReplaceValue(valueMap, parts[i]);
|
||||
if (StringUtils.isNotEmpty(replaceValue) && !parts[i].equals(replaceValue)) {
|
||||
parts[i] = replaceValue;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有任何部分发生变化,则重新构建字符串
|
||||
if (changed) {
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for (int i = 0; i < parts.length; i++) {
|
||||
sb.append(parts[i]);
|
||||
// 除了最后一个部分,其他部分后面都需要加上"%"
|
||||
if (i < parts.length - 1) {
|
||||
sb.append("%");
|
||||
}
|
||||
}
|
||||
rightStringValue.setValue(sb.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void visit(InExpression inExpression) {
|
||||
if (!(inExpression.getLeftExpression() instanceof Column)) {
|
||||
return;
|
||||
|
||||
@@ -16,16 +16,21 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
|
||||
@Slf4j
|
||||
public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
private Set<String> fieldNames;
|
||||
private Map<String, String> fieldNameMap = new HashMap<>();
|
||||
|
||||
private static Set<String> HAVING_AGG_TYPES = Set.of("SUM", "AVG", "MAX", "MIN", "COUNT");
|
||||
|
||||
public FiledFilterReplaceVisitor(Map<String, String> fieldNameMap) {
|
||||
this.fieldNameMap = fieldNameMap;
|
||||
this.fieldNames = fieldNameMap.keySet();
|
||||
}
|
||||
|
||||
public FiledFilterReplaceVisitor(Set<String> fieldNames) {
|
||||
this.fieldNames = fieldNames;
|
||||
@@ -82,7 +87,22 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return result;
|
||||
if (leftExpression instanceof Column) {
|
||||
Column leftColumn = (Column) leftExpression;
|
||||
String agg = fieldNameMap.get(leftColumn.getColumnName());
|
||||
if (agg != null && HAVING_AGG_TYPES.contains(agg.toUpperCase())) {
|
||||
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||
if (Objects.nonNull(expression)) {
|
||||
result.add(expression);
|
||||
return result;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
Function leftFunction = (Function) leftExpression;
|
||||
@@ -102,14 +122,24 @@ public class FiledFilterReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
return null;
|
||||
}
|
||||
|
||||
Expression expression = parseCondExpression(comparisonOperator, condExpr);
|
||||
if (Objects.nonNull(expression)) {
|
||||
result.add(expression);
|
||||
return result;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private Expression parseCondExpression(ComparisonOperator comparisonOperator, String condExpr) {
|
||||
try {
|
||||
String comparisonOperatorStr = comparisonOperator.toString();
|
||||
ComparisonOperator parsedExpression =
|
||||
(ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
comparisonOperator.setLeftExpression(parsedExpression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(parsedExpression.getRightExpression());
|
||||
comparisonOperator.setASTNode(parsedExpression.getASTNode());
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr));
|
||||
return result;
|
||||
return CCJSqlParserUtil.parseCondExpression(comparisonOperatorStr);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
|
||||
@@ -309,7 +309,7 @@ public class SqlAddHelper {
|
||||
}
|
||||
}
|
||||
|
||||
public static String addHaving(String sql, Set<String> fieldNames) {
|
||||
public static String addHaving(String sql, Map<String, String> fieldNames) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
|
||||
if (!(selectStatement instanceof PlainSelect)) {
|
||||
|
||||
@@ -727,7 +727,7 @@ public class SqlReplaceHelper {
|
||||
List<PlainSelect> plainSelects = SqlSelectHelper.getPlainSelects(plainSelectList);
|
||||
for (PlainSelect plainSelect : plainSelects) {
|
||||
if (Objects.nonNull(plainSelect.getFromItem())) {
|
||||
Table table = (Table) plainSelect.getFromItem();
|
||||
Table table = SqlSelectHelper.getTable(plainSelect.getFromItem());
|
||||
if (table.getName().equals(tableName)) {
|
||||
replacePlainSelectByExpr(plainSelect, replace);
|
||||
if (SqlSelectHelper.hasAggregateFunction(plainSelect)) {
|
||||
|
||||
@@ -723,6 +723,44 @@ public class SqlSelectHelper {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Table getTable(FromItem fromItem) {
|
||||
Table table = null;
|
||||
if (fromItem instanceof Table) {
|
||||
table = (Table) fromItem;
|
||||
} else if (fromItem instanceof ParenthesedSelect) {
|
||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) fromItem;
|
||||
if (parenthesedSelect.getSelect() instanceof PlainSelect) {
|
||||
PlainSelect subSelect = (PlainSelect) parenthesedSelect.getSelect();
|
||||
table = getTable(subSelect.getSelectBody());
|
||||
} else if (parenthesedSelect.getSelect() instanceof SetOperationList) {
|
||||
table = getTable(parenthesedSelect.getSelect());
|
||||
}
|
||||
}
|
||||
return table;
|
||||
}
|
||||
|
||||
public static Table getTable(Select select) {
|
||||
if (select == null) {
|
||||
return null;
|
||||
}
|
||||
List<PlainSelect> plainSelectList = getWithItem(select);
|
||||
if (!CollectionUtils.isEmpty(plainSelectList)) {
|
||||
List<PlainSelect> selectList = new ArrayList<>(plainSelectList);
|
||||
Table table = getTable(selectList.get(0));
|
||||
return table;
|
||||
}
|
||||
if (select instanceof PlainSelect) {
|
||||
PlainSelect plainSelect = (PlainSelect) select;
|
||||
return getTable(plainSelect.getFromItem());
|
||||
} else if (select instanceof SetOperationList) {
|
||||
SetOperationList setOperationList = (SetOperationList) select;
|
||||
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
|
||||
return getTable(setOperationList.getSelects().get(0));
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static String getDbTableName(String sql) {
|
||||
Table table = getTable(sql);
|
||||
return table.getFullyQualifiedName();
|
||||
|
||||
@@ -28,6 +28,8 @@ public class ChatModelConfig implements Serializable {
|
||||
private Boolean logRequests = false;
|
||||
private Boolean logResponses = false;
|
||||
private Boolean enableSearch = false;
|
||||
private Boolean jsonFormat = false;
|
||||
private String jsonFormatType = "json_schema";
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||
|
||||
@@ -1,27 +1,26 @@
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import lombok.Getter;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
public class DataEvent extends ApplicationEvent {
|
||||
|
||||
private List<DataItem> dataItems;
|
||||
private final List<DataItem> dataItems;
|
||||
|
||||
private EventType eventType;
|
||||
private final EventType eventType;
|
||||
|
||||
public DataEvent(Object source, List<DataItem> dataItems, EventType eventType) {
|
||||
private final String userName;
|
||||
|
||||
public DataEvent(Object source, List<DataItem> dataItems, EventType eventType,
|
||||
String userName) {
|
||||
super(source);
|
||||
this.dataItems = dataItems;
|
||||
this.eventType = eventType;
|
||||
this.userName = userName;
|
||||
}
|
||||
|
||||
public List<DataItem> getDataItems() {
|
||||
return dataItems;
|
||||
}
|
||||
|
||||
public EventType getEventType() {
|
||||
return eventType;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,4 +5,6 @@ public class DimensionConstants {
|
||||
public static final String DIMENSION_TIME_FORMAT = "time_format";
|
||||
|
||||
public static final String DIMENSION_TYPE = "dimension_type";
|
||||
|
||||
public static final String DIMENSION_DATA_TYPE = "dimension_data_type";
|
||||
}
|
||||
|
||||
@@ -22,4 +22,6 @@ public class Text2SQLExemplar implements Serializable {
|
||||
private String dbSchema;
|
||||
|
||||
private String sql;
|
||||
|
||||
protected double similarity; // 传递相似度,可以作为样本筛选的依据
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum TypeEnums {
|
||||
METRIC, DIMENSION, TAG, DOMAIN, DATASET, MODEL, UNKNOWN
|
||||
METRIC, DIMENSION, VALUE, TAG, DOMAIN, DATASET, MODEL, UNKNOWN
|
||||
}
|
||||
|
||||
@@ -15,5 +15,5 @@ public interface ChatModelService {
|
||||
|
||||
ChatModel updateChatModel(ChatModel chatModel, User user);
|
||||
|
||||
void deleteChatModel(Integer id);
|
||||
void deleteChatModel(Integer id, User user);
|
||||
}
|
||||
|
||||
@@ -79,7 +79,12 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteChatModel(Integer id) {
|
||||
public void deleteChatModel(Integer id, User user) {
|
||||
ChatModel chatModel = getChatModel(id);
|
||||
if (!checkAdminPermission(user, chatModel)) {
|
||||
throw new RuntimeException("没有权限删除该大模型");
|
||||
}
|
||||
|
||||
removeById(id);
|
||||
}
|
||||
|
||||
@@ -103,4 +108,13 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
||||
chatModelDO.setConfig(JsonUtil.toString(chatModel.getConfig()));
|
||||
return chatModelDO;
|
||||
}
|
||||
|
||||
private boolean checkAdminPermission(User user, ChatModel chatModel) {
|
||||
String admin = chatModel.getAdmin();
|
||||
if (user.isSuperAdmin()) {
|
||||
return true;
|
||||
}
|
||||
return admin != null && admin.equals(user.getName())
|
||||
|| chatModel.getCreatedBy().equals(user.getName());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,10 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||
results.forEach(ret -> {
|
||||
ret.getRetrieval().forEach(r -> {
|
||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
||||
Text2SQLExemplar tmp = // 传递相似度,可以作为样本筛选的依据
|
||||
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
|
||||
tmp.setSimilarity(r.getSimilarity());
|
||||
exemplars.add(tmp);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, SystemConfigDO>
|
||||
@@ -38,8 +39,8 @@ public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, Sys
|
||||
return systemConfigDb;
|
||||
}
|
||||
|
||||
private SystemConfig getSystemConfigFromDB() {
|
||||
List<SystemConfigDO> list = list();
|
||||
private SystemConfig getSystemConfigFromDB() { // 加上id ,如果有多条记录,会出错
|
||||
List<SystemConfigDO> list = this.lambdaQuery().eq(SystemConfigDO::getId, 1).list();
|
||||
if (CollectionUtils.isEmpty(list)) {
|
||||
SystemConfig systemConfig = new SystemConfig();
|
||||
systemConfig.setId(1);
|
||||
|
||||
@@ -7,6 +7,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dify.DifyAiChatModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -25,6 +26,11 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
|
||||
.modelName(modelConfig.getModelName()).timeOut(modelConfig.getTimeOut()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -35,6 +36,11 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
|
||||
@@ -6,6 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -27,6 +28,11 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
|
||||
|
||||
@@ -4,9 +4,12 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
|
||||
public interface ModelFactory {
|
||||
ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
|
||||
|
||||
OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig);
|
||||
|
||||
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -41,6 +43,20 @@ public class ModelProvider {
|
||||
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||
}
|
||||
|
||||
public static StreamingChatLanguageModel getChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||
modelConfig = DEMO_CHAT_MODEL;
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatStreamingModel(modelConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException(
|
||||
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel getEmbeddingModel() {
|
||||
return getEmbeddingModel(null);
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -28,6 +29,11 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
throw new RuntimeException("待开发");
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
|
||||
@@ -6,6 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -22,10 +23,26 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl()).modelName(modelConfig.getModelName())
|
||||
.apiKey(modelConfig.keyDecrypt()).apiVersion(modelConfig.getApiVersion())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses());
|
||||
if (modelConfig.getJsonFormat() != null && modelConfig.getJsonFormat()) {
|
||||
openAiChatModelBuilder.strictJsonSchema(true)
|
||||
.responseFormat(modelConfig.getJsonFormatType());
|
||||
}
|
||||
return openAiChatModelBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiStreamingChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
||||
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
|
||||
@@ -42,6 +42,6 @@ public class TextSegmentConvert {
|
||||
if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) {
|
||||
return null;
|
||||
}
|
||||
return textSegment.metadata().get(QUERY_ID);
|
||||
return textSegment.metadata().getString(QUERY_ID);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,8 +338,8 @@ class SqlAddHelperTest {
|
||||
List<String> groupByFields = new ArrayList<>();
|
||||
groupByFields.add("department");
|
||||
|
||||
Set<String> fieldNames = new HashSet<>();
|
||||
fieldNames.add("pv");
|
||||
Map<String, String> fieldNames = new HashMap<>();
|
||||
fieldNames.put("pv", "sum");
|
||||
|
||||
String replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
||||
|
||||
@@ -355,6 +355,14 @@ class SqlAddHelperTest {
|
||||
Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' "
|
||||
+ "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10",
|
||||
replaceSql);
|
||||
|
||||
sql = "SELECT 数据日期,访问用户数 FROM 超音数数据集 WHERE 访问次数 > 10 GROUP BY 数据日期";
|
||||
|
||||
fieldNames.put("访问次数", "sum");
|
||||
replaceSql = SqlAddHelper.addHaving(sql, fieldNames);
|
||||
|
||||
Assert.assertEquals("SELECT 数据日期, 访问用户数 FROM 超音数数据集 GROUP BY 数据日期 HAVING 访问次数 > 10",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
@@ -18,6 +18,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
|
||||
@@ -65,12 +66,23 @@ public class SemanticParseInfo implements Serializable {
|
||||
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
|
||||
|
||||
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
|
||||
if (difference == 0) {
|
||||
if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势
|
||||
if (!o1.getDataSetId().equals(o2.getDataSetId())) {
|
||||
List<SchemaElementMatch> elementMatches1 = o1.getElementMatches().stream()
|
||||
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
|
||||
List<SchemaElementMatch> elementMatches2 = o2.getElementMatches().stream()
|
||||
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
|
||||
if (elementMatches1.size() > elementMatches2.size()) {
|
||||
return -1;
|
||||
} else if (elementMatches1.size() < elementMatches2.size()) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
|
||||
if (difference == 0) {
|
||||
if (Math.abs(difference) < 0.0005) {
|
||||
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
|
||||
}
|
||||
if (difference == 0) {
|
||||
if (Math.abs(difference) < 0.0005) {
|
||||
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,4 +16,7 @@ public class SqlInfo implements Serializable {
|
||||
|
||||
// SQL to be executed finally
|
||||
private String querySQL;
|
||||
|
||||
// Physical SQL corrected by LLM for performance optimization
|
||||
private String correctedQuerySQL;
|
||||
}
|
||||
|
||||
@@ -8,5 +8,6 @@ public enum ChatWorkflowState {
|
||||
VALIDATING,
|
||||
SQL_CORRECTING,
|
||||
PROCESSING,
|
||||
PHYSICAL_SQL_CORRECTING,
|
||||
FINISHED
|
||||
}
|
||||
|
||||
@@ -12,4 +12,6 @@ public class DictSingleTaskReq {
|
||||
private TypeEnums type;
|
||||
@NotNull
|
||||
private Long itemId;
|
||||
private String startDate;
|
||||
private String endDate;
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ public class PageSchemaItemReq extends PageBaseReq {
|
||||
private String createdBy;
|
||||
private List<Long> domainIds = Lists.newArrayList();
|
||||
private List<Long> modelIds = Lists.newArrayList();
|
||||
private Long dataSetId;
|
||||
private Integer sensitiveLevel;
|
||||
private Integer status;
|
||||
private String key;
|
||||
|
||||
@@ -3,7 +3,11 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.pojo.*;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -21,14 +25,22 @@ import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.schema.Table;
|
||||
import net.sf.jsqlparser.statement.select.*;
|
||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
import net.sf.jsqlparser.statement.select.Limit;
|
||||
import net.sf.jsqlparser.statement.select.Offset;
|
||||
import net.sf.jsqlparser.statement.select.OrderByElement;
|
||||
import net.sf.jsqlparser.statement.select.ParenthesedSelect;
|
||||
import net.sf.jsqlparser.statement.select.PlainSelect;
|
||||
import net.sf.jsqlparser.statement.select.SelectItem;
|
||||
import org.apache.commons.codec.digest.DigestUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@@ -176,7 +188,7 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
|
||||
private List<SelectItem<?>> buildSelectItems(QueryStructReq queryStructReq) {
|
||||
List<SelectItem<?>> selectItems = new ArrayList<>();
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
Set<String> groups = new HashSet<>(queryStructReq.getGroups());
|
||||
|
||||
if (!CollectionUtils.isEmpty(groups)) {
|
||||
for (String group : groups) {
|
||||
@@ -236,7 +248,7 @@ public class QueryStructReq extends SemanticQueryReq {
|
||||
}
|
||||
|
||||
private GroupByElement buildGroupByElement(QueryStructReq queryStructReq) {
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
Set<String> groups = new HashSet<>(queryStructReq.getGroups());
|
||||
if ((!CollectionUtils.isEmpty(groups) && !queryStructReq.getAggregators().isEmpty())
|
||||
|| !queryStructReq.getMetricFilters().isEmpty()) {
|
||||
GroupByElement groupByElement = new GroupByElement();
|
||||
|
||||
@@ -23,9 +23,11 @@ public class SqlExecuteReq {
|
||||
private Integer limit = 1000;
|
||||
|
||||
public String getSql() {
|
||||
if (StringUtils.isNotBlank(sql) && sql.endsWith(";")) {
|
||||
sql = sql.substring(0, sql.length() - 1);
|
||||
if (StringUtils.isNotBlank(sql)) {
|
||||
sql = sql.replaceAll("^[\\n]+|[\\n]+$", "");
|
||||
sql = StringUtils.removeEnd(sql, ";");
|
||||
}
|
||||
|
||||
return String.format(LIMIT_WRAPPER, sql, limit);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -11,7 +12,8 @@ import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** Perform SQL corrections on the "Having" section in S2SQL. */
|
||||
@@ -29,8 +31,9 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
Map<String, String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getName,
|
||||
e -> Optional.ofNullable(e.getDefaultAgg()).orElse("")));
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* 物理SQL修正器 - 使用LLM优化物理SQL性能
|
||||
*/
|
||||
@Slf4j
|
||||
public class LLMPhysicalSqlCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "PHYSICAL_SQL_CORRECTOR";
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a senior database performance optimization expert experienced in SQL tuning."
|
||||
+ "\n\n#Task: You will be provided with a user question and the corresponding physical SQL query,"
|
||||
+ " please analyze and optimize this SQL to improve query performance." + "\n\n#Rules:"
|
||||
+ "\n1. DO NOT add or introduce any new fields, columns, or aliases that are not in the original SQL."
|
||||
+ "\n2. Push WHERE conditions into JOIN ON clauses when possible to reduce intermediate result sets."
|
||||
+ "\n3. Optimize JOIN order by placing smaller tables or tables with selective conditions first."
|
||||
+ "\n4. For date range conditions, ensure they are applied as early as possible in the query execution."
|
||||
+ "\n5. Remove or comment out database-specific index hints (like USE INDEX) that may cause syntax errors."
|
||||
+ "\n6. ONLY modify the structure and order of existing elements, do not change field names or add new ones."
|
||||
+ "\n7. Ensure the optimized SQL is syntactically correct and logically equivalent to the original."
|
||||
+ "\n\n#Question: {{question}}" + "\n\n#OriginalSQL: {{sql}}";
|
||||
|
||||
public LLMPhysicalSqlCorrector() {
|
||||
ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("物理SQL修正")
|
||||
.appModule(AppModule.CHAT).description("通过大模型对物理SQL做性能优化").enable(false).build());
|
||||
}
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
static class PhysicalSql {
|
||||
@Description("either positive or negative")
|
||||
private String opinion;
|
||||
|
||||
@Description("optimized sql if negative")
|
||||
private String sql;
|
||||
}
|
||||
|
||||
interface PhysicalSqlExtractor {
|
||||
PhysicalSql generatePhysicalSql(String text);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY);
|
||||
if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|
||||
|| !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
PhysicalSqlExtractor extractor =
|
||||
AiServices.create(PhysicalSqlExtractor.class, chatLanguageModel);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
|
||||
semanticParseInfo, chatApp.getPrompt());
|
||||
PhysicalSql physicalSql =
|
||||
extractor.generatePhysicalSql(prompt.toUserMessage().singleText());
|
||||
keyPipelineLog.info("LLMPhysicalSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
|
||||
physicalSql);
|
||||
if ("NEGATIVE".equalsIgnoreCase(physicalSql.getOpinion())
|
||||
&& StringUtils.isNotBlank(physicalSql.getSql())) {
|
||||
semanticParseInfo.getSqlInfo().setCorrectedQuerySQL(physicalSql.getSql());
|
||||
}
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
|
||||
String promptTemplate) {
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("question", queryText);
|
||||
variable.put("sql", semanticParseInfo.getSqlInfo().getQuerySQL());
|
||||
|
||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||
}
|
||||
}
|
||||
@@ -8,43 +8,107 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class KnowledgeBaseService {
|
||||
private static volatile Map<Long, List<DictWord>> dimValueAliasMap = new HashMap<>();
|
||||
private static final Map<Long, List<DictWord>> dimValueAliasMap = new ConcurrentHashMap<>();
|
||||
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
|
||||
|
||||
/**
|
||||
* Get dimension value alias map (read-only).
|
||||
*
|
||||
* @return unmodifiable view of the map
|
||||
*/
|
||||
public static Map<Long, List<DictWord>> getDimValueAlias() {
|
||||
return dimValueAliasMap;
|
||||
return Collections.unmodifiableMap(dimValueAliasMap);
|
||||
}
|
||||
|
||||
/**
|
||||
* Add dimension value aliases with deduplication. Thread-safe implementation using
|
||||
* ConcurrentHashMap.
|
||||
*
|
||||
* @param dimId dimension ID
|
||||
* @param newWords new words to add
|
||||
* @return updated list of aliases for the dimension
|
||||
*/
|
||||
public static List<DictWord> addDimValueAlias(Long dimId, List<DictWord> newWords) {
|
||||
List<DictWord> dimValueAlias =
|
||||
dimValueAliasMap.containsKey(dimId) ? dimValueAliasMap.get(dimId)
|
||||
: new ArrayList<>();
|
||||
Set<String> wordSet =
|
||||
dimValueAlias
|
||||
.stream().map(word -> String.format("%s_%s_%s",
|
||||
word.getNatureWithFrequency(), word.getWord(), word.getAlias()))
|
||||
.collect(Collectors.toSet());
|
||||
for (DictWord dictWord : newWords) {
|
||||
String key = String.format("%s_%s_%s", dictWord.getNatureWithFrequency(),
|
||||
dictWord.getWord(), dictWord.getAlias());
|
||||
if (!wordSet.contains(key)) {
|
||||
dimValueAlias.add(dictWord);
|
||||
}
|
||||
if (dimId == null || CollectionUtils.isEmpty(newWords)) {
|
||||
return dimValueAliasMap.get(dimId);
|
||||
}
|
||||
|
||||
// Use computeIfAbsent and synchronized block for thread safety
|
||||
synchronized (dimValueAliasMap) {
|
||||
List<DictWord> dimValueAlias =
|
||||
dimValueAliasMap.computeIfAbsent(dimId, k -> new ArrayList<>());
|
||||
|
||||
// Build deduplication key set
|
||||
Set<String> existingKeys = dimValueAlias.stream().map(word -> buildDedupKey(word))
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
// Add new words with deduplication
|
||||
for (DictWord dictWord : newWords) {
|
||||
String key = buildDedupKey(dictWord);
|
||||
if (!existingKeys.contains(key)) {
|
||||
dimValueAlias.add(dictWord);
|
||||
existingKeys.add(key);
|
||||
}
|
||||
}
|
||||
|
||||
return dimValueAlias;
|
||||
}
|
||||
dimValueAliasMap.put(dimId, dimValueAlias);
|
||||
return dimValueAlias;
|
||||
}
|
||||
|
||||
public void updateSemanticKnowledge(List<DictWord> natures) {
|
||||
/**
|
||||
* Remove dimension value aliases by dimension ID.
|
||||
*
|
||||
* @param dimId dimension ID to remove, or null to clear all
|
||||
*/
|
||||
public static void removeDimValueAlias(Long dimId) {
|
||||
if (dimId == null) {
|
||||
dimValueAliasMap.clear();
|
||||
log.info("Cleared all dimension value aliases");
|
||||
} else {
|
||||
dimValueAliasMap.remove(dimId);
|
||||
log.info("Removed dimension value alias for dimId: {}", dimId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Build deduplication key for DictWord.
|
||||
*
|
||||
* @param word the DictWord object
|
||||
* @return deduplication key string
|
||||
*/
|
||||
private static String buildDedupKey(DictWord word) {
|
||||
return String.format("%s_%s_%s", word.getNatureWithFrequency(), word.getWord(),
|
||||
word.getAlias());
|
||||
}
|
||||
|
||||
/**
|
||||
* Update semantic knowledge (incremental add, no clearing). Use this method to add new words
|
||||
* without removing existing data.
|
||||
*
|
||||
* @param natures the words to add
|
||||
*/
|
||||
public void updateSemanticKnowledge(List<DictWord> natures) {
|
||||
lock.writeLock().lock();
|
||||
try {
|
||||
updateSemanticKnowledgeInternal(natures);
|
||||
} finally {
|
||||
lock.writeLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
private void updateSemanticKnowledgeInternal(List<DictWord> natures) {
|
||||
List<DictWord> prefixes = natures.stream().filter(
|
||||
entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
@@ -60,52 +124,82 @@ public class KnowledgeBaseService {
|
||||
SearchService.loadSuffix(suffixes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload all knowledge (full replacement with clearing). Use this method to rebuild the entire
|
||||
* knowledge base.
|
||||
*
|
||||
* @param natures all words to load
|
||||
*/
|
||||
public void reloadAllData(List<DictWord> natures) {
|
||||
// 1. reload custom knowledge
|
||||
// 1. reload custom knowledge (executed outside lock to avoid long blocking)
|
||||
try {
|
||||
HanlpHelper.reloadCustomDictionary();
|
||||
} catch (Exception e) {
|
||||
log.error("reloadCustomDictionary error", e);
|
||||
}
|
||||
|
||||
// 2. update online knowledge
|
||||
if (CollectionUtils.isNotEmpty(dimValueAliasMap)) {
|
||||
for (Long dimId : dimValueAliasMap.keySet()) {
|
||||
natures.addAll(dimValueAliasMap.get(dimId));
|
||||
}
|
||||
}
|
||||
updateOnlineKnowledge(natures);
|
||||
}
|
||||
|
||||
private void updateOnlineKnowledge(List<DictWord> natures) {
|
||||
// 2. acquire write lock, clear trie and rebuild (short operation)
|
||||
lock.writeLock().lock();
|
||||
try {
|
||||
updateSemanticKnowledge(natures);
|
||||
} catch (Exception e) {
|
||||
log.error("updateSemanticKnowledge error", e);
|
||||
SearchService.clear();
|
||||
|
||||
if (CollectionUtils.isNotEmpty(dimValueAliasMap)) {
|
||||
for (Long dimId : dimValueAliasMap.keySet()) {
|
||||
natures.addAll(dimValueAliasMap.get(dimId));
|
||||
}
|
||||
}
|
||||
updateSemanticKnowledgeInternal(natures);
|
||||
} finally {
|
||||
lock.writeLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public List<S2Term> getTerms(String text, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
|
||||
lock.readLock().lock();
|
||||
try {
|
||||
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> prefixSearch(String key, int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
lock.readLock().lock();
|
||||
try {
|
||||
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
lock.readLock().lock();
|
||||
try {
|
||||
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> suffixSearch(String key, int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
lock.readLock().lock();
|
||||
try {
|
||||
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public List<HanlpMapResult> suffixSearchByModel(String key, int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
lock.readLock().lock();
|
||||
try {
|
||||
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
|
||||
} finally {
|
||||
lock.readLock().unlock();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.PriorityBlockingQueue;
|
||||
|
||||
import static com.hankcs.hanlp.utility.Predefine.logger;
|
||||
|
||||
@@ -40,7 +41,7 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
|
||||
public static int MAX_SIZE = 10;
|
||||
public static Boolean removeDuplicates = true;
|
||||
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES =
|
||||
public static ConcurrentHashMap<String, PriorityBlockingQueue<Term>> NATURE_TO_VALUES =
|
||||
new ConcurrentHashMap<>();
|
||||
private static boolean addToSuggesterTrie = true;
|
||||
|
||||
@@ -116,9 +117,17 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
dictWord.setAlias(word.toLowerCase());
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length >= 2) {
|
||||
Long dimId = Long.parseLong(
|
||||
nature.split(DictWordType.NATURE_SPILT)[split.length - 1]);
|
||||
KnowledgeBaseService.addDimValueAlias(dimId, Arrays.asList(dictWord));
|
||||
try {
|
||||
Long dimId = Long.parseLong(
|
||||
nature.split(DictWordType.NATURE_SPILT)[split.length - 1]);
|
||||
KnowledgeBaseService.addDimValueAlias(dimId,
|
||||
Arrays.asList(dictWord));
|
||||
} catch (NumberFormatException e) {
|
||||
logger.warning(path + " : 非标准文件,不存入KnowledgeBaseService");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,9 +155,10 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
}
|
||||
for (int i = 0; i < attribute.nature.length; i++) {
|
||||
Nature nature = attribute.nature[i];
|
||||
PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
|
||||
PriorityBlockingQueue<Term> priorityQueue =
|
||||
NATURE_TO_VALUES.get(nature.toString());
|
||||
if (Objects.isNull(priorityQueue)) {
|
||||
priorityQueue = new PriorityQueue<>(MAX_SIZE,
|
||||
priorityQueue = new PriorityBlockingQueue<>(MAX_SIZE,
|
||||
Comparator.comparingInt(Term::getFrequency).reversed());
|
||||
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
|
||||
}
|
||||
|
||||
@@ -24,14 +24,15 @@ import java.util.PriorityQueue;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.TreeSet;
|
||||
import java.util.concurrent.PriorityBlockingQueue;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class SearchService {
|
||||
|
||||
public static final int SEARCH_SIZE = 200;
|
||||
private static BinTrie<List<String>> trie;
|
||||
private static BinTrie<List<String>> suffixTrie;
|
||||
private static volatile BinTrie<List<String>> trie;
|
||||
private static volatile BinTrie<List<String>> suffixTrie;
|
||||
|
||||
static {
|
||||
trie = new BinTrie<>();
|
||||
@@ -200,7 +201,7 @@ public class SearchService {
|
||||
public static List<String> getDimensionValue(DimensionValueReq dimensionValueReq) {
|
||||
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId()
|
||||
+ DictWordType.NATURE_SPILT + dimensionValueReq.getElementID();
|
||||
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
PriorityBlockingQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
@@ -175,6 +175,7 @@ public class FileHandlerImpl implements FileHandler {
|
||||
private DictValueResp convert2Resp(String lineStr) {
|
||||
DictValueResp dictValueResp = new DictValueResp();
|
||||
if (StringUtils.isNotEmpty(lineStr)) {
|
||||
lineStr = StringUtils.stripStart(lineStr, null);
|
||||
String[] itemArray = lineStr.split("\\s+");
|
||||
if (Objects.nonNull(itemArray) && itemArray.length >= 3) {
|
||||
dictValueResp.setValue(itemArray[0].replace("#", " "));
|
||||
|
||||
@@ -100,8 +100,6 @@ public class HanlpHelper {
|
||||
FileHelper.deleteCacheFile(HanLP.Config.CustomDictionaryPath);
|
||||
FileHelper.resetCustomPath(getDynamicCustomDictionary());
|
||||
}
|
||||
// 3.clear trie
|
||||
SearchService.clear();
|
||||
|
||||
boolean reload = getDynamicCustomDictionary().reload();
|
||||
if (reload) {
|
||||
|
||||
@@ -129,7 +129,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
Map<MatchText, List<T>> matchResult = matchStrategy.match(chatQueryContext, terms,
|
||||
chatQueryContext.getRequest().getDataSetIds());
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
if (Objects.isNull(matchResult) || matchResult.isEmpty()) {
|
||||
return matches;
|
||||
}
|
||||
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||
|
||||
@@ -12,12 +12,17 @@ import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -72,18 +77,39 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
||||
}
|
||||
}
|
||||
|
||||
protected void executeTasks(List<Callable<Void>> tasks) {
|
||||
protected Set<T> executeTasks(List<Supplier<List<T>>> tasks) {
|
||||
|
||||
Function<Supplier<List<T>>, Supplier<List<T>>> decorator = taskDecorator();
|
||||
List<CompletableFuture<List<T>>> futures;
|
||||
if (decorator == null) {
|
||||
futures = tasks.stream().map(t -> CompletableFuture.supplyAsync(t, executor)).toList();
|
||||
} else {
|
||||
futures = tasks.stream()
|
||||
.map(t -> CompletableFuture.supplyAsync(decorator.apply(t), executor)).toList();
|
||||
}
|
||||
|
||||
CompletableFuture<List<T>> listCompletableFuture =
|
||||
CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]))
|
||||
.thenApply(v -> futures.stream()
|
||||
.flatMap(listFuture -> listFuture.join().stream())
|
||||
.collect(Collectors.toList()));
|
||||
try {
|
||||
executor.invokeAll(tasks);
|
||||
for (Callable<Void> future : tasks) {
|
||||
future.call();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
List<T> ts = listCompletableFuture.get();
|
||||
Set<T> results = new HashSet<>();
|
||||
selectResultInOneRound(results, ts);
|
||||
return results;
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException("Task execution interrupted", e);
|
||||
} catch (ExecutionException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public Function<Supplier<List<T>>, Supplier<List<T>>> taskDecorator() {
|
||||
return null;
|
||||
}
|
||||
|
||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||
if (MapModeEnum.STRICT.equals(mapModeEnum)) {
|
||||
return 1.0d;
|
||||
|
||||
@@ -17,6 +17,8 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
@@ -76,6 +78,22 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
|
||||
return allElements;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Function<Supplier<List<DatabaseMapResult>>, Supplier<List<DatabaseMapResult>>> taskDecorator() {
|
||||
List<SchemaElement> schemaElements = allElements.get();
|
||||
if (CollectionUtils.isEmpty(schemaElements)) {
|
||||
return null;
|
||||
}
|
||||
return (t) -> (Supplier<List<DatabaseMapResult>>) () -> {
|
||||
try {
|
||||
allElements.set(schemaElements);
|
||||
return t.get();
|
||||
} finally {
|
||||
allElements.remove();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private Double getThreshold(ChatQueryContext chatQueryContext) {
|
||||
Double threshold =
|
||||
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
@@ -23,8 +24,7 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Supplier;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
|
||||
@@ -140,7 +140,6 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
*/
|
||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
|
||||
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
|
||||
int embeddingMapperBatch = Integer
|
||||
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
@@ -153,12 +152,11 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||
|
||||
// Create and execute tasks for each batch
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
List<Supplier<List<EmbeddingResult>>> tasks = new ArrayList<>();
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
tasks.add(
|
||||
createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
|
||||
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, useLlm));
|
||||
}
|
||||
executeTasks(tasks);
|
||||
Set<EmbeddingResult> results = executeTasks(tasks);
|
||||
|
||||
// Apply LLM filtering if enabled
|
||||
if (useLlm) {
|
||||
@@ -167,9 +165,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
variable.put("retrievedInfo", JSONObject.toJSONString(results));
|
||||
|
||||
Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatQueryContext.getRequest().getChatAppConfig()
|
||||
.get("REWRITE_MULTI_TURN").getChatModelConfig());
|
||||
ChatModelConfig chatModelConfig = null;
|
||||
if (chatQueryContext.getRequest().getChatAppConfig() != null && chatQueryContext
|
||||
.getRequest().getChatAppConfig().containsKey("REWRITE_MULTI_TURN")) {
|
||||
chatModelConfig = chatQueryContext.getRequest().getChatAppConfig()
|
||||
.get("REWRITE_MULTI_TURN").getChatModelConfig();
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatModelConfig);
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
|
||||
|
||||
if (StringUtils.isBlank(response)) {
|
||||
@@ -191,20 +193,13 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
* @param chatQueryContext The context of the chat query
|
||||
* @param detectDataSetIds Target dataset IDs
|
||||
* @param queryTextsSub Sub-list of query texts to process
|
||||
* @param results Shared result set for collecting results
|
||||
* @param useLlm Whether to use LLM
|
||||
* @return Callable task
|
||||
* @return Supplier task
|
||||
*/
|
||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
|
||||
return () -> {
|
||||
List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
|
||||
queryTextsSub, chatQueryContext, useLlm);
|
||||
synchronized (results) {
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
private Supplier<List<EmbeddingResult>> createTask(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, List<String> queryTextsSub, boolean useLlm) {
|
||||
return () -> detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext,
|
||||
useLlm);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
||||
@@ -90,13 +87,15 @@ public class KeywordMapper extends BaseMapper {
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord()).build();
|
||||
// doDimValueAliasLogic 将维度值别名进行替换成真实维度值
|
||||
doDimValueAliasLogic(schemaElementMatch);
|
||||
doDimValueAliasLogic(schemaElementMatch,
|
||||
chatQueryContext.getSemanticSchema().getDimensionValues());
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void doDimValueAliasLogic(SchemaElementMatch schemaElementMatch) {
|
||||
private void doDimValueAliasLogic(SchemaElementMatch schemaElementMatch,
|
||||
List<SchemaElement> dimensionValues) {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
if (SchemaElementType.VALUE.equals(element.getType())) {
|
||||
Long dimId = element.getId();
|
||||
@@ -112,6 +111,18 @@ public class KeywordMapper extends BaseMapper {
|
||||
schemaElementMatch.setWord(wordTech);
|
||||
}
|
||||
}
|
||||
SchemaElement dimensionValue = dimensionValues.stream()
|
||||
.filter(dimValue -> dimId.equals(dimValue.getId())).findFirst().orElse(null);
|
||||
if (dimensionValue != null) {
|
||||
SchemaValueMap dimValue =
|
||||
dimensionValue.getSchemaValueMaps().stream().filter(schemaValueMap -> {
|
||||
return StringUtils.equals(schemaValueMap.getBizName(), word)
|
||||
|| schemaValueMap.getAlias().contains(word);
|
||||
}).findFirst().orElse(null);
|
||||
if (dimValue != null) {
|
||||
schemaElementMatch.setWord(dimValue.getTechName());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,7 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -26,8 +25,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
||||
Set<Long> detectDataSetIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
||||
String text = chatQueryContext.getRequest().getQueryText();
|
||||
Set<T> results = ConcurrentHashMap.newKeySet();
|
||||
List<Callable<Void>> tasks = new ArrayList<>();
|
||||
List<Supplier<List<T>>> tasks = new ArrayList<>();
|
||||
|
||||
for (int startIndex = 0; startIndex <= text.length() - 1;) {
|
||||
for (int index = startIndex; index <= text.length();) {
|
||||
@@ -35,27 +33,20 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index).trim();
|
||||
Callable<Void> task = createTask(chatQueryContext, detectDataSetIds,
|
||||
detectSegment, offset, results);
|
||||
Supplier<List<T>> task =
|
||||
createTask(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||
tasks.add(task);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
executeTasks(tasks);
|
||||
Set<T> results = executeTasks(tasks);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset, Set<T> results) {
|
||||
return () -> {
|
||||
List<T> oneRoundResults =
|
||||
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||
synchronized (results) {
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
private Supplier<List<T>> createTask(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||
return () -> detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||
}
|
||||
|
||||
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
|
||||
|
||||
@@ -57,6 +57,10 @@ public class ParserConfig extends ParameterConfig {
|
||||
new Parameter("s2.parser.field.count.threshold", "0", "语义字段个数阈值",
|
||||
"如果映射字段小于该阈值,则将数据集所有字段输入LLM", "number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_FORMAT_JSON_TYPE =
|
||||
new Parameter("s2.parser.format.json-type", "", "请求llm返回json格式,默认不设置json格式",
|
||||
"选项:json_schema或者json_object", "string", "语义解析配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_RULE_CORRECTOR_ENABLE,
|
||||
|
||||
@@ -2,9 +2,11 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -14,9 +16,11 @@ import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -24,6 +28,8 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FORMAT_JSON_TYPE;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
@@ -31,6 +37,10 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "S2SQL_PARSER";
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
public static final String INSTRUCTION =
|
||||
"#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||
@@ -74,7 +84,13 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
// 2.generate sql generation prompt for each self-consistency inference
|
||||
ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig());
|
||||
ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
|
||||
if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
|
||||
chatModelConfig.setJsonFormat(true);
|
||||
chatModelConfig
|
||||
.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
|
||||
|
||||
@@ -14,11 +14,10 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.DimensionConstants.*;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
|
||||
|
||||
@Component
|
||||
@@ -51,12 +50,33 @@ public class PromptHelper {
|
||||
// use random collection of exemplars for each self-consistency inference
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
// only shuffle the exemplars from config
|
||||
List<Text2SQLExemplar> subList=shuffledList.subList(llmReq.getDynamicExemplars().size(),shuffledList.size());
|
||||
Collections.shuffle(subList);
|
||||
results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber)));
|
||||
List<Text2SQLExemplar> same = shuffledList.stream() // 相似度极高的话,先找出来
|
||||
.filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList());
|
||||
List<Text2SQLExemplar> noSame = shuffledList.stream()
|
||||
.filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList());
|
||||
if ((noSame.size() - same.size()) > fewShotNumber) {// 去除部分最低分
|
||||
noSame.sort(Comparator.comparingDouble(Text2SQLExemplar::getSimilarity));
|
||||
noSame = noSame.subList((noSame.size() - fewShotNumber) / 2, noSame.size());
|
||||
}
|
||||
Text2SQLExemplar mostSimilar = noSame.get(noSame.size() - 1);
|
||||
Collections.shuffle(noSame);
|
||||
List<Text2SQLExemplar> ts;
|
||||
if (same.size() > 0) {// 一样的话,必须作为提示语
|
||||
ts = new ArrayList<>();
|
||||
int needSize = Math.min(noSame.size() + same.size(), fewShotNumber);
|
||||
if (needSize > same.size()) {
|
||||
ts.addAll(noSame.subList(0, needSize - same.size()));
|
||||
}
|
||||
ts.addAll(same);
|
||||
} else { // 至少要一个最像的
|
||||
ts = noSame.subList(0, Math.min(noSame.size(), fewShotNumber));
|
||||
if (!ts.contains(mostSimilar)) {
|
||||
ts.remove(ts.size() - 1);
|
||||
ts.add(mostSimilar);
|
||||
}
|
||||
}
|
||||
results.add(ts);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
@@ -123,6 +143,10 @@ public class PromptHelper {
|
||||
dimension.getAlias().forEach(a -> alias.append(a).append(";"));
|
||||
dimensionStr.append(" ALIAS '").append(alias).append("'");
|
||||
}
|
||||
if (Objects.nonNull(dimension.getExtInfo().get(DIMENSION_DATA_TYPE))) {
|
||||
dimensionStr.append(" DATATYPE '")
|
||||
.append(dimension.getExtInfo().get(DIMENSION_DATA_TYPE)).append("'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(dimension.getTimeFormat())) {
|
||||
dimensionStr.append(" FORMAT '").append(dimension.getTimeFormat()).append("'");
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ public class TimeRangeParser implements SemanticParser {
|
||||
|
||||
private static final Pattern RECENT_PATTERN_CN = Pattern.compile(
|
||||
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
|
||||
private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("(\\d{8})");
|
||||
private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("\\b(\\d{8})\\b");
|
||||
private static final DateFormat DATE_FORMAT_NUMBER = new SimpleDateFormat("yyyyMMdd");
|
||||
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.headless.api.pojo.DBColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
|
||||
import com.tencent.supersonic.headless.core.pojo.ConnectInfo;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.sql.*;
|
||||
import java.util.ArrayList;
|
||||
@@ -148,7 +147,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor {
|
||||
String url = connectionInfo.getUrl().toLowerCase();
|
||||
|
||||
// 设置通用属性
|
||||
properties.setProperty("user", connectionInfo.getUserName());
|
||||
String userName = Optional.ofNullable(connectionInfo.getUserName()).orElse("");
|
||||
properties.setProperty("user", userName);
|
||||
|
||||
|
||||
String password = Optional.ofNullable(connectionInfo.getPassword()).orElse("");
|
||||
|
||||
@@ -10,6 +10,7 @@ import java.sql.DatabaseMetaData;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
|
||||
@Slf4j
|
||||
public class DuckdbAdaptor extends DefaultDbAdaptor {
|
||||
@@ -23,7 +24,7 @@ public class DuckdbAdaptor extends DefaultDbAdaptor {
|
||||
String tableName) throws SQLException {
|
||||
List<DBColumn> dbColumns = Lists.newArrayList();
|
||||
DatabaseMetaData metaData = getDatabaseMetaData(connectInfo);
|
||||
ResultSet columns = metaData.getColumns(schemaName, null, tableName, null);
|
||||
ResultSet columns = metaData.getColumns(null, schemaName, tableName, null);
|
||||
while (columns.next()) {
|
||||
String columnName = columns.getString("COLUMN_NAME");
|
||||
String dataType = columns.getString("TYPE_NAME");
|
||||
@@ -42,4 +43,9 @@ public class DuckdbAdaptor extends DefaultDbAdaptor {
|
||||
return sql.replaceAll("`", "");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Properties getProperties(ConnectInfo connectionInfo) {
|
||||
return new Properties();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp;
|
||||
import lombok.Data;
|
||||
@@ -24,6 +25,7 @@ public class QueryStatement {
|
||||
private SemanticSchemaResp semanticSchema;
|
||||
private Integer limit = 1000;
|
||||
private Boolean isTranslated = false;
|
||||
private User user;
|
||||
|
||||
public boolean isOk() {
|
||||
return StringUtils.isBlank(errMsg) && StringUtils.isNotBlank(sql);
|
||||
|
||||
@@ -47,24 +47,20 @@ public class SqlQueryParser implements QueryParser {
|
||||
SqlQuery sqlQuery = queryStatement.getSqlQuery();
|
||||
List<String> queryFields = SqlSelectHelper.getAllSelectFields(sqlQuery.getSql());
|
||||
Set<String> queryAliases = SqlSelectHelper.getAliasFields(sqlQuery.getSql());
|
||||
Set<String> ontologyMetricsDimensions = Collections.synchronizedSet(new HashSet<String>());
|
||||
Set<String> ontologyBizNameMetricsDimensions = Collections.synchronizedSet(new HashSet<>());
|
||||
List<Pair<String, String>> ontologyMetricsDimensionsAndBizName =
|
||||
Collections.synchronizedList(new ArrayList<>());
|
||||
queryFields.removeAll(queryAliases);
|
||||
Ontology ontology = queryStatement.getOntology();
|
||||
OntologyQuery ontologyQuery = buildOntologyQuery(ontology, queryFields);
|
||||
Set<String> queryFieldsSet = new HashSet<>(queryFields);
|
||||
ontologyQuery.getMetrics().forEach(m -> {
|
||||
ontologyMetricsDimensions.add(m.getName());
|
||||
ontologyBizNameMetricsDimensions.add(m.getBizName());
|
||||
ontologyMetricsDimensionsAndBizName.add(Pair.of(m.getName(), m.getBizName()));
|
||||
});
|
||||
ontologyQuery.getDimensions().forEach(d -> {
|
||||
ontologyMetricsDimensions.add(d.getName());
|
||||
ontologyBizNameMetricsDimensions.add(d.getBizName());
|
||||
ontologyMetricsDimensionsAndBizName.add(Pair.of(d.getName(), d.getBizName()));
|
||||
});
|
||||
// check if there are fields not matched with any metric or dimension
|
||||
|
||||
if (!(queryFieldsSet.containsAll(ontologyMetricsDimensions)
|
||||
|| queryFieldsSet.containsAll(ontologyBizNameMetricsDimensions))) {
|
||||
if (!allFieldMatched(queryFieldsSet, ontologyMetricsDimensionsAndBizName)) {
|
||||
List<String> semanticFields = Lists.newArrayList();
|
||||
ontologyQuery.getMetrics().forEach(m -> semanticFields.add(m.getName()));
|
||||
ontologyQuery.getDimensions().forEach(d -> semanticFields.add(d.getName()));
|
||||
@@ -103,6 +99,16 @@ public class SqlQueryParser implements QueryParser {
|
||||
log.info("parse sqlQuery [{}] ", sqlQuery);
|
||||
}
|
||||
|
||||
private boolean allFieldMatched(Set<String> queryFields,
|
||||
List<Pair<String, String>> ontologyMetricsDimensionsAndBizName) {
|
||||
for (Pair<String, String> pair : ontologyMetricsDimensionsAndBizName) {
|
||||
if (!(queryFields.contains(pair.getLeft()) || queryFields.contains(pair.getRight()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void aliasesWithBackticks(QueryStatement queryStatement) {
|
||||
String sql = queryStatement.getSqlQuery().getSql();
|
||||
sql = SqlReplaceHelper.replaceAliasWithBackticks(sql);
|
||||
|
||||
@@ -36,21 +36,14 @@ public class DataModelNode extends SemanticNode {
|
||||
&& !dataModel.getModelDetail().getSqlQuery().isEmpty()) {
|
||||
sqlTable = dataModel.getModelDetail().getSqlQuery();
|
||||
// if model has sqlVariables, parse sqlVariables
|
||||
if (Objects.nonNull(dataModel.getModelDetail().getSqlVariables()) &&
|
||||
!(CollectionUtils.isEmpty(dataModel.getModelDetail().getSqlVariables()))) {
|
||||
if (Objects.nonNull(dataModel.getModelDetail().getSqlVariables())
|
||||
&& !(CollectionUtils.isEmpty(dataModel.getModelDetail().getSqlVariables()))) {
|
||||
sqlTable = SqlVariableParseUtils.parse(sqlTable,
|
||||
dataModel.getModelDetail().getSqlVariables(), Lists.newArrayList());
|
||||
}
|
||||
} else if (dataModel.getModelDetail().getTableQuery() != null
|
||||
&& !dataModel.getModelDetail().getTableQuery().isEmpty()) {
|
||||
if (dataModel.getModelDetail().getDbType()
|
||||
.equalsIgnoreCase(EngineType.POSTGRESQL.getName())) {
|
||||
String fullTableName = String.join(".public.",
|
||||
dataModel.getModelDetail().getTableQuery().split("\\."));
|
||||
sqlTable = "SELECT * FROM " + fullTableName;
|
||||
} else {
|
||||
sqlTable = "SELECT * FROM " + dataModel.getModelDetail().getTableQuery();
|
||||
}
|
||||
sqlTable = "SELECT * FROM " + dataModel.getModelDetail().getTableQuery();
|
||||
}
|
||||
|
||||
// String filterSql = dataModel.getFilterSql();
|
||||
|
||||
@@ -88,7 +88,7 @@ public class SqlBuilder {
|
||||
GraphPath<String, DefaultEdge> selectedGraphPath = null;
|
||||
for (String fromModel : queryModels) {
|
||||
for (String toModel : queryModels) {
|
||||
if (fromModel != toModel) {
|
||||
if (!fromModel.equals(toModel)) {
|
||||
GraphPath<String, DefaultEdge> path = dijkstraAlg.getPath(fromModel, toModel);
|
||||
if (isGraphPathContainsAll(path, queryModels)) {
|
||||
selectedGraphPath = path;
|
||||
@@ -100,13 +100,13 @@ public class SqlBuilder {
|
||||
if (selectedGraphPath == null) {
|
||||
return dataModels;
|
||||
}
|
||||
Set<String> modelNames = Sets.newHashSet();
|
||||
Set<String> modelNames = Sets.newLinkedHashSet();
|
||||
for (DefaultEdge edge : selectedGraphPath.getEdgeList()) {
|
||||
modelNames.add(selectedGraphPath.getGraph().getEdgeSource(edge));
|
||||
modelNames.add(selectedGraphPath.getGraph().getEdgeTarget(edge));
|
||||
}
|
||||
return modelNames.stream().map(m -> ontology.getModelMap().get(m))
|
||||
.collect(Collectors.toSet());
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
}
|
||||
|
||||
private boolean isGraphPathContainsAll(GraphPath<String, DefaultEdge> graphPath,
|
||||
|
||||
@@ -102,7 +102,7 @@ public class DimValueAspect {
|
||||
continue;
|
||||
}
|
||||
for (DimensionResp dimension : dimensions) {
|
||||
if (!expression.getFieldName().equals(dimension.getName())
|
||||
if (!expression.getFieldName().equals(dimension.getBizName())
|
||||
|| CollectionUtils.isEmpty(dimension.getDimValueMaps())) {
|
||||
continue;
|
||||
}
|
||||
@@ -124,6 +124,10 @@ public class DimValueAspect {
|
||||
sql = SqlReplaceHelper.replaceValue(sql, filedNameToValueMap);
|
||||
log.debug("correctorSql after replacing:{}", sql);
|
||||
querySqlReq.setSql(sql);
|
||||
if (StringUtils.isEmpty(querySqlReq.getSqlInfo().getParsedS2SQL())
|
||||
&& StringUtils.isEmpty(querySqlReq.getSqlInfo().getCorrectedS2SQL())) {
|
||||
querySqlReq.getSqlInfo().setQuerySQL(sql);
|
||||
}
|
||||
Map<String, Map<String, String>> techNameToBizName = getTechNameToBizName(dimensions);
|
||||
|
||||
SemanticQueryResp queryResultWithColumns = (SemanticQueryResp) joinPoint.proceed();
|
||||
|
||||
@@ -123,7 +123,9 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
|
||||
// 3 translate query
|
||||
QueryStatement queryStatement = buildQueryStatement(queryReq, user);
|
||||
semanticTranslator.translate(queryStatement);
|
||||
if (!queryStatement.isTranslated()) {
|
||||
semanticTranslator.translate(queryStatement);
|
||||
}
|
||||
|
||||
// Check whether the dimensions of the metric drill-down are correct temporarily,
|
||||
// add the abstraction of a validator later.
|
||||
@@ -296,6 +298,9 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL());
|
||||
queryStatement.setIsTranslated(true);
|
||||
}
|
||||
if (queryStatement != null) {
|
||||
queryStatement.setUser(user);
|
||||
}
|
||||
return queryStatement;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,5 +13,7 @@ public interface DimensionDOCustomMapper {
|
||||
|
||||
void batchUpdateStatus(List<DimensionDO> dimensionDOS);
|
||||
|
||||
void batchUpdate(List<DimensionDO> dimensionDOS);
|
||||
|
||||
List<DimensionDO> queryDimensions(DimensionsFilter dimensionsFilter);
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ public interface MetricDOCustomMapper {
|
||||
|
||||
void batchUpdateStatus(List<MetricDO> metricDOS);
|
||||
|
||||
void batchUpdate(List<MetricDO> metricDOS);
|
||||
|
||||
void batchPublish(List<MetricDO> metricDOS);
|
||||
|
||||
void batchUnPublish(List<MetricDO> metricDOS);
|
||||
|
||||
@@ -16,6 +16,8 @@ public interface DimensionRepository {
|
||||
|
||||
void batchUpdateStatus(List<DimensionDO> dimensionDOS);
|
||||
|
||||
void batchUpdate(List<DimensionDO> dimensionDOS);
|
||||
|
||||
DimensionDO getDimensionById(Long id);
|
||||
|
||||
List<DimensionDO> getDimension(DimensionFilter dimensionFilter);
|
||||
|
||||
@@ -17,6 +17,8 @@ public interface MetricRepository {
|
||||
|
||||
void batchUpdateStatus(List<MetricDO> metricDOS);
|
||||
|
||||
void batchUpdateMetric(List<MetricDO> metricDOS);
|
||||
|
||||
void batchPublish(List<MetricDO> metricDOS);
|
||||
|
||||
void batchUnPublish(List<MetricDO> metricDOS);
|
||||
|
||||
@@ -17,12 +17,11 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.utils.DictUtils;
|
||||
import com.xkzhangsan.time.utils.CollectionUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.codehaus.plexus.util.StringUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
@@ -46,6 +46,11 @@ public class DimensionRepositoryImpl implements DimensionRepository {
|
||||
dimensionDOCustomMapper.batchUpdateStatus(dimensionDOS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void batchUpdate(List<DimensionDO> dimensionDOS) {
|
||||
dimensionDOCustomMapper.batchUpdate(dimensionDOS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DimensionDO getDimensionById(Long id) {
|
||||
return dimensionDOMapper.selectById(id);
|
||||
@@ -83,10 +88,10 @@ public class DimensionRepositoryImpl implements DimensionRepository {
|
||||
}
|
||||
if (StringUtils.isNotBlank(dimensionFilter.getKey())) {
|
||||
String key = dimensionFilter.getKey();
|
||||
queryWrapper.lambda().like(DimensionDO::getName, key).or()
|
||||
queryWrapper.and(qw -> qw.lambda().like(DimensionDO::getName, key).or()
|
||||
.like(DimensionDO::getBizName, key).or().like(DimensionDO::getDescription, key)
|
||||
.or().like(DimensionDO::getAlias, key).or()
|
||||
.like(DimensionDO::getCreatedBy, key);
|
||||
.like(DimensionDO::getCreatedBy, key));
|
||||
}
|
||||
|
||||
return dimensionDOMapper.selectList(queryWrapper);
|
||||
|
||||
@@ -53,6 +53,11 @@ public class MetricRepositoryImpl implements MetricRepository {
|
||||
metricDOCustomMapper.batchUpdateStatus(metricDOS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void batchUpdateMetric(List<MetricDO> metricDOS) {
|
||||
metricDOCustomMapper.batchUpdate(metricDOS);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void batchPublish(List<MetricDO> metricDOS) {
|
||||
metricDOCustomMapper.batchPublish(metricDOS);
|
||||
@@ -104,14 +109,14 @@ public class MetricRepositoryImpl implements MetricRepository {
|
||||
if (StringUtils.isNotBlank(metricFilter.getCreatedBy())) {
|
||||
queryWrapper.lambda().eq(MetricDO::getCreatedBy, metricFilter.getCreatedBy());
|
||||
}
|
||||
if (Objects.nonNull(metricFilter.getIsPublish()) && metricFilter.getIsPublish() == 1) {
|
||||
if (Objects.nonNull(metricFilter.getIsPublish())) {
|
||||
queryWrapper.lambda().eq(MetricDO::getIsPublish, metricFilter.getIsPublish());
|
||||
}
|
||||
if (StringUtils.isNotBlank(metricFilter.getKey())) {
|
||||
String key = metricFilter.getKey();
|
||||
queryWrapper.lambda().like(MetricDO::getName, key).or().like(MetricDO::getBizName, key)
|
||||
.or().like(MetricDO::getDescription, key).or().like(MetricDO::getAlias, key)
|
||||
.or().like(MetricDO::getCreatedBy, key);
|
||||
queryWrapper.lambda().and(wrapper -> wrapper.like(MetricDO::getName, key).or()
|
||||
.like(MetricDO::getBizName, key).or().like(MetricDO::getDescription, key).or()
|
||||
.like(MetricDO::getAlias, key).or().like(MetricDO::getCreatedBy, key));
|
||||
}
|
||||
|
||||
return metricDOMapper.selectList(queryWrapper);
|
||||
|
||||
@@ -46,8 +46,10 @@ public class ChatModelController {
|
||||
}
|
||||
|
||||
@DeleteMapping("/{id}")
|
||||
public boolean deleteModel(@PathVariable("id") Integer id) {
|
||||
chatModelService.deleteChatModel(id);
|
||||
public boolean deleteModel(@PathVariable("id") Integer id,
|
||||
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
chatModelService.deleteChatModel(id, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@@ -50,10 +51,9 @@ public class DataSetController {
|
||||
|
||||
@GetMapping("/getDataSetList")
|
||||
public List<DataSetResp> getDataSetList(@RequestParam("domainId") Long domainId) {
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setDomainId(domainId);
|
||||
metaFilter.setStatus(StatusEnum.ONLINE.getCode());
|
||||
return dataSetService.getDataSetList(metaFilter);
|
||||
List<Integer> statuCodeList =
|
||||
Arrays.asList(StatusEnum.ONLINE.getCode(), StatusEnum.OFFLINE.getCode());
|
||||
return dataSetService.getDataSetList(domainId, statuCodeList);
|
||||
}
|
||||
|
||||
@DeleteMapping("/{id}")
|
||||
|
||||
@@ -64,8 +64,10 @@ public class DatabaseController {
|
||||
}
|
||||
|
||||
@DeleteMapping("/{id}")
|
||||
public boolean deleteDatabase(@PathVariable("id") Long id) {
|
||||
databaseService.deleteDatabase(id);
|
||||
public boolean deleteDatabase(@PathVariable("id") Long id, HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
databaseService.deleteDatabase(id, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ public interface DataSetService {
|
||||
|
||||
List<DataSetResp> getDataSetList(MetaFilter metaFilter);
|
||||
|
||||
List<DataSetResp> getDataSetList(Long domainId, List<Integer> statuCodesList);
|
||||
|
||||
void delete(Long id, User user);
|
||||
|
||||
Map<Long, List<Long>> getModelIdToDataSetIds(List<Long> dataSetIds, User user);
|
||||
|
||||
@@ -34,7 +34,7 @@ public interface DatabaseService {
|
||||
|
||||
List<DatabaseResp> getDatabaseList(User user);
|
||||
|
||||
void deleteDatabase(Long databaseId);
|
||||
void deleteDatabase(Long databaseId, User user);
|
||||
|
||||
List<String> getCatalogs(Long id) throws SQLException;
|
||||
|
||||
|
||||
@@ -27,10 +27,15 @@ public interface DimensionService {
|
||||
|
||||
DimensionResp createDimension(DimensionReq dimensionReq, User user) throws Exception;
|
||||
|
||||
void alterDimensionBatch(List<DimensionReq> dimensionReqs, Long modelId, User user)
|
||||
throws Exception;
|
||||
|
||||
void createDimensionBatch(List<DimensionReq> dimensionReqs, User user) throws Exception;
|
||||
|
||||
void updateDimension(DimensionReq dimensionReq, User user) throws Exception;
|
||||
|
||||
void updateDimensionBatch(List<DimensionReq> dimensionReqs, User user) throws Exception;
|
||||
|
||||
PageInfo<DimensionResp> queryDimension(PageDimensionReq pageDimensionReq);
|
||||
|
||||
List<DimensionResp> queryDimensions(DimensionsFilter dimensionsFilter);
|
||||
@@ -39,13 +44,15 @@ public interface DimensionService {
|
||||
|
||||
void deleteDimension(Long id, User user);
|
||||
|
||||
void deleteDimensionBatch(List<Long> idList, User user);
|
||||
|
||||
List<DimensionResp> getDimensionInModelCluster(Long modelId);
|
||||
|
||||
List<String> mockAlias(DimensionReq dimensionReq, String mockType, User user);
|
||||
|
||||
List<DimValueMap> mockDimensionValueAlias(DimensionReq dimensionReq, User user);
|
||||
|
||||
void sendDimensionEventBatch(List<Long> modelIds, EventType eventType);
|
||||
void sendDimensionEventBatch(List<Long> modelIds, EventType eventType, User user);
|
||||
|
||||
DataEvent getAllDataEvents();
|
||||
|
||||
|
||||
@@ -26,8 +26,12 @@ public interface MetricService {
|
||||
|
||||
void createMetricBatch(List<MetricReq> metricReqs, User user) throws Exception;
|
||||
|
||||
void alterMetricBatch(List<MetricReq> metricReqs, Long modelId, User user) throws Exception;
|
||||
|
||||
MetricResp updateMetric(MetricReq metricReq, User user) throws Exception;
|
||||
|
||||
void updateMetricBatch(List<MetricReq> metricReqs, User user) throws Exception;
|
||||
|
||||
void batchUpdateStatus(MetaBatchReq metaBatchReq, User user);
|
||||
|
||||
void batchPublish(List<Long> metricIds, User user);
|
||||
@@ -40,6 +44,8 @@ public interface MetricService {
|
||||
|
||||
void deleteMetric(Long id, User user) throws Exception;
|
||||
|
||||
void deleteMetricBatch(List<Long> idList, User user);
|
||||
|
||||
PageInfo<MetricResp> queryMetricMarket(PageMetricReq pageMetricReq, User user);
|
||||
|
||||
PageInfo<MetricResp> queryMetric(PageMetricReq pageMetricReq, User user);
|
||||
@@ -64,7 +70,7 @@ public interface MetricService {
|
||||
|
||||
MetricQueryDefaultConfig getMetricQueryDefaultConfig(Long metricId, User user);
|
||||
|
||||
void sendMetricEventBatch(List<Long> modelIds, EventType eventType);
|
||||
void sendMetricEventBatch(List<Long> modelIds, EventType eventType, User user);
|
||||
|
||||
List<MetricResp> queryMetrics(MetricsFilter metricsFilter);
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ import com.tencent.supersonic.headless.api.pojo.request.*;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.UnAvailableItemResp;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
|
||||
import com.tencent.supersonic.headless.server.pojo.ModelFilter;
|
||||
|
||||
import java.sql.SQLException;
|
||||
@@ -53,5 +55,9 @@ public interface ModelService {
|
||||
|
||||
void batchUpdateStatus(MetaBatchReq metaBatchReq, User user);
|
||||
|
||||
Dimension updateDimension(DimensionReq dimensionReq, User user);
|
||||
void updateModelByDimAndMetric(Long modelId, List<DimensionReq> dimensionReqList,
|
||||
List<MetricReq> metricReqList, User user);
|
||||
|
||||
void deleteModelDetailByDimAndMetric(Long modelId, List<DimensionDO> dimensionReqList,
|
||||
List<MetricDO> metricReqList);
|
||||
}
|
||||
|
||||
@@ -102,6 +102,20 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
|
||||
return list(wrapper).stream().map(this::convert).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataSetResp> getDataSetList(Long domainId, List<Integer> statuCodesList) {
|
||||
if (domainId == null || CollectionUtils.isEmpty(statuCodesList)) {
|
||||
return List.of();
|
||||
}
|
||||
QueryWrapper<DataSetDO> wrapper = new QueryWrapper<>();
|
||||
wrapper.lambda().eq(DataSetDO::getDomainId, domainId);
|
||||
wrapper.lambda().in(DataSetDO::getStatus, statuCodesList);
|
||||
wrapper.lambda().ne(DataSetDO::getStatus, StatusEnum.DELETED.getCode());
|
||||
|
||||
return list(wrapper).stream().map(this::convert).collect(Collectors.toList());
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(Long id, User user) {
|
||||
DataSetDO dataSetDO = getById(id);
|
||||
|
||||
@@ -138,7 +138,12 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteDatabase(Long databaseId) {
|
||||
public void deleteDatabase(Long databaseId, User user) {
|
||||
DatabaseResp databaseResp = getDatabase(databaseId);
|
||||
if (!checkAdminPermission(user, databaseResp)) {
|
||||
throw new RuntimeException("没有权限删除该数据库");
|
||||
}
|
||||
|
||||
ModelFilter modelFilter = new ModelFilter();
|
||||
modelFilter.setDatabaseId(databaseId);
|
||||
modelFilter.setIncludesDetail(false);
|
||||
@@ -282,6 +287,7 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
||||
public List<DBColumn> getColumns(Long id, String catalog, String db, String table)
|
||||
throws SQLException {
|
||||
DatabaseResp databaseResp = getDatabase(id);
|
||||
catalog = StringUtils.isEmpty(catalog) ? db : catalog;
|
||||
return getColumns(databaseResp, catalog, db, table);
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ public class DictWordService {
|
||||
return;
|
||||
}
|
||||
setPreDictWords(dictWords);
|
||||
knowledgeBaseService.reloadAllData(getAllDictWords());
|
||||
knowledgeBaseService.reloadAllData(dictWords);
|
||||
long duration = System.currentTimeMillis() - startTime;
|
||||
log.info("Dictionary has been regularly reloaded in {} milliseconds", duration);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user