Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-10 11:07:06 +00:00)
Compare commits
310 Commits
.gitignore (vendored): 4 changed lines

@@ -15,4 +15,6 @@ assembly/runtime/*
/assembly/deploy
/runtime
**/.flattened-pom.xml
__pycache__/
chm_db/
__pycache__/
/dict
CHANGELOG.md: 47 changed lines

@@ -3,25 +3,40 @@
- All notable changes to this project will be documented in this file.
- "Breaking Changes" describes any changes that may break existing functionality or cause compatibility issues with previous versions.

## SuperSonic [0.7.5] - 2023-10-13

### Added
- add SQL generation improvement optimization, support LLM SQL, Logic SQL, and Physical SQL display.
- add showcase functionality to support recommending similar questions.
- add frontend modification of filtering conditions and re-querying feature.
- support nested query functionality in semantic.
- support switching queries between multiple parsers in the frontend.

### Updated
- optimizing the build and deployment of the project.
- overall optimization of the SQL Corrector functionality.

### Fixed
- fix execute error on mysql <=5.7

## SuperSonic [0.7.4] - 2023-09-10

### Added
- add llm parser config
- add datasource agg_time option
- add function name adaptor in clickhouse
- add dimension and metric show in dsl

### Updated
- update user guide doc
- update query building of plugin in default model
- update some core API constructs to keep naming consistency
- update ConfigureDemo config
- update the association mechanism so that invisible dimensions and metrics will no longer be associated

### Fixed
- fix hasAggregateFunction logic in SqlParserSelectHelper

## SuperSonic [0.7.3] - 2023-08-29
README.md: 22 changed lines

@@ -2,20 +2,20 @@

# SuperSonic (超音数)

**SuperSonic is an out-of-the-box yet highly extensible framework for building a data chatbot**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of metrics/dimensions/entities, along with their meaning, context and relationships) on top of physical data models, and no data modification or copying is required. Meanwhile, SuperSonic is designed to be pluggable, allowing new functionalities to be added through plugins and core components to be integrated with other systems.
**SuperSonic is an out-of-the-box yet highly extensible framework for building ChatBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of metrics/dimensions/entities, along with their meaning, context and relationships) on top of physical data models, and no data modification or copying is required. Meanwhile, SuperSonic is designed to be pluggable, allowing new functionalities to be added through plugins and core components to be integrated with other systems.

<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>

## Motivation

The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language queries into SQL queries. While some works show promising results, they are still not applicable to real-world scenarios.
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language into SQL (so called text2sql or nl2sql). While some works exhibit promising results, their **reliability** is inadequate for real-world applications.

From our perspective, the key to filling the real-world gap lies in three aspects:
1. Complement the LLM-based semantic parser with rule-based semantic parsers to improve **efficiency**(in terms of latency and cost).
2. Augment semantic parsing with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to improve **accuracy** and **stability**.
3. Introduce a semantic layer encapsulating underlying data context(joins, formulas, etc) to reduce **complexity**.
1. Introduce a semantic layer encapsulating underlying data context(joins, formulas, etc) to reduce **complexity**.
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to mitigate **hallucination**.
3. Utilize heuristic rules when necessary to improve **efficiency**(in terms of latency and cost).

With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of data chatbot, we decide to open source SuperSonic as an extensible framework.
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of ChatBI, we decide to open source SuperSonic as an extensible framework.

## Out-of-the-box Features

@@ -40,7 +40,7 @@ The high-level architecture and main process flow is as follows:

- **Semantic Corrector:** checks validity of extracted semantic information and performs correction and optimization if needed.

- **Semantic Layer:** performs execution according to extracted semantic information. It generates SQL queries and executes them against physical data models.
- **Semantic Interpreter:** performs execution according to extracted semantic information. It generates SQL statements and executes them against physical data models.

- **Chat Plugin:** extends functionality with third-party tools. The LLM is going to select the most suitable one, given all configured plugins with function description and sample questions.

@@ -49,15 +49,15 @@ The high-level architecture and main process flow is as follows:

SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:

- Download the latest prebuilt binary from the [release page](https://github.com/tencentmusic/supersonic/releases)
- Run script "bin/start-standalone.sh" to start services (one java process and one python process)
- Run script "bin/supersonic-daemon.sh" to start services (one java process and one python process)
- Visit http://localhost:9080 in the browser to start exploration

## Build and Delopment
## Build and Development

Please refer to project [wiki](https://github.com/tencentmusic/supersonic/wiki).

## WeChat Contact

Please join the chat group to suggest feedbacks or ideas:
Please follow SuperSonic wechat official account:

<img src="./docs/images/wechat_contact.jpeg" height="40%" width="40%" align="center"/>
<img src="./docs/images/supersonic_wechat_oa.png" height="50%" width="50%" align="center"/>
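For readers who want to try the updated quick-start flow above, a minimal shell sketch follows; the release asset name and the version placeholder are assumptions, so check the release page for the actual file name:

```bash
# Download and unpack a prebuilt release (asset name is assumed; see the releases page)
wget https://github.com/tencentmusic/supersonic/releases/download/<version>/supersonic-standalone.tar.gz
tar -zxvf supersonic-standalone.tar.gz
cd supersonic-standalone

# Start the Java process and the Python process, then open the chat UI
sh bin/supersonic-daemon.sh start
# browse to http://localhost:9080
```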
README_CN.md: 18 changed lines

@@ -6,12 +6,12 @@

## 项目动机

大型语言模型(LLMs)如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们还并不适用于实际场景。
大型语言模型(LLMs)如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们的可靠性还达不到生产可用的要求。

在我们看来,为了在实际场景发挥价值,有三个关键点:
1. 在基于大模型语义解析器基础上,增加基于规则的解析器,提升语义解析的**效率**。
2. 加入模式映射器和语义修正器,来增强语义解析能力,提升语义解析的**准确性**和**稳定性**。
3. 引入语义模型层,封装底层数据的上下文(关联、公式等),降低语义解析的**复杂性**。
1. 引入语义模型层,封装底层数据的上下文(关联、公式等),降低SQL生成的**复杂度**。
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。

为了验证上述想法,我们开发了超音数项目,并将其应用在实际的内部产品中。与此同时,我们将超音数作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。

@@ -30,7 +30,7 @@

<img src="./docs/images/supersonic_components.png" height="65%" width="65%" align="center"/>

- **知识库(Knowledge Base):** 定期从语义模型中提取相关的模式信息,构建词典和索引,以便后续的模式映射。
- **模型知识库(Knowledge Base):** 定期从语义模型中提取相关的模式信息,构建词典和索引,以便后续的模式映射。

- **模式映射器(Schema Mapper):** 将自然语言文本在知识库中进行匹配,为后续的语义解析提供相关信息。

@@ -38,7 +38,7 @@

- **语义修正器(Semantic Corrector):** 检查语义信息的合法性,对不合法的信息做修正和优化处理。

- **语义模型层(Semantic Layer):** 根据语义信息生成物理SQL执行查询。
- **语义解释器(Semantic Interpreter):** 根据语义信息生成物理SQL执行查询。

- **问答插件(Chat Plugin):** 通过第三方工具扩展功能。给定所有配置的插件及其功能描述和示例问题,大语言模型将选择最合适的插件。

@@ -47,7 +47,7 @@

超音数自带样例的语义模型和问答对话,只需以下三步即可快速体验:

- 从[release page](https://github.com/tencentmusic/supersonic/releases)下载预先构建好的发行包
- 运行 "bin/start-standalone.sh"启动服务(一个Java进程和一个Python进程)
- 运行 "bin/supersonic-daemon.sh"启动服务(一个Java进程和一个Python进程)
- 在浏览器访问http://localhost:9080 开启探索

## 如何构建和部署

@@ -56,6 +56,6 @@

## 微信联系方式

欢迎加入微信群反馈建议:
欢迎关注微信公众号:

<img src="./docs/images/wechat_contact.jpeg" height="40%" width="40%" align="center"/>
<img src="./docs/images/supersonic_wechat_oa.png" height="50%" width="50%" align="center"/>
@@ -1,30 +1,53 @@
@echo off
setlocal

chcp 65001
set "sbinDir=%~dp0"
set "baseDir=%~dp0.."
set "buildDir=%baseDir%\build"
set "runtimeDir=%baseDir%\..\runtime"
set "pip_path=pip3"

rem 1. build semantic chat service
rem 1. build backend java modules
del /q "%buildDir%\*.tar.gz" 2>NUL

call mvn -f "%baseDir%\..\pom.xml" clean package -DskipTests

rem 2. move package to build
echo f|xcopy "%baseDir%\..\launchers\standalone\target\*.tar.gz" "%buildDir%\supersonic-standalone.tar.gz"
echo f|xcopy "%baseDir%\..\launchers\semantic\target\*.tar.gz" "%buildDir%\supersonic-semantic.tar.gz"
echo f|xcopy "%baseDir%\..\launchers\chat\target\*.tar.gz" "%buildDir%\supersonic-chat.tar.gz"

rem 3. build webapp
rem 3. build frontend webapp
cd "%baseDir%\..\webapp"
call start-fe-prod.bat
copy /y "%baseDir%\..\webapp\supersonic-webapp.tar.gz" "%buildDir%\"

rem 4. copy webapp to java classpath
cd "%buildDir%"
tar -zxvf supersonic-webapp.tar.gz
move supersonic-webapp webapp
move webapp ..\..\launchers\standalone\target\classes

rem 5. build backend python modules
echo "start installing python modules with pip: ${pip_path}"
set requirementPath="%baseDir%/../chat/python/requirements.txt"
%pip_path% install -r %requirementPath%
echo "install python modules success"

call :BUILD_RUNTIME

:BUILD_RUNTIME
rem 6. reset runtime
rd /s /q "%runtimeDir%"
mkdir "%runtimeDir%"
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
move "%%f" "%runtimeDir%\supersonic-standalone"
)

rem 7. copy webapp to runtime
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
rd /s /q "%buildDir%\supersonic-webapp"

endlocal
assembly/bin/supersonic-build.sh: 35 changed lines (Normal file → Executable file)

@@ -1,29 +1,48 @@
#!/usr/bin/env bash

set -x
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
runtimeDir=$baseDir/runtime
buildDir=$baseDir/build
chmod +x $sbinDir/supersonic-common.sh
source $sbinDir/supersonic-common.sh

cd $baseDir

#1. build semantic chat service
#1. build backend java modules
rm -fr ${buildDir}/*.tar.gz
rm -fr dist

set +x

mvn -f $baseDir/../ clean package -DskipTests

#2. move package to build
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic.tar.gz
cp $baseDir/../launchers/semantic/target/*.tar.gz ${buildDir}/supersonic-semantic.tar.gz
cp $baseDir/../launchers/chat/target/*.tar.gz ${buildDir}/supersonic-chat.tar.gz
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic-standalone.tar.gz

#3. build webapp
#3. build frontend webapp
chmod +x $baseDir/../webapp/start-fe-prod.sh
cd ../webapp
sh ./start-fe-prod.sh
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/

#4. copy webapp to java classpath
cd $buildDir
tar xvf supersonic-webapp.tar.gz
mv supersonic-webapp webapp
mv webapp ../../launchers/standalone/target/classes
cp -fr webapp ../../launchers/semantic/target/classes
cp -fr webapp ../../launchers/chat/target/classes
cp -fr webapp ../../launchers/standalone/target/classes
rm -fr ${buildDir}/webapp

#5. build backend python modules
echo "start installing python modules with pip: ${pip_path}"
requirementPath=$baseDir/../chat/python/requirements.txt
${pip_path} install -r ${requirementPath}
echo "install python modules success"

#6. reset runtime
rm -fr $runtimeDir/*
moveAllToRuntime
setEnvToWeb chat
setEnvToWeb semantic
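A hedged sketch of how this build script is usually invoked from a repository checkout; it chains into the daemon script shown further below (both paths come from this diff):

```bash
# Build the Java modules, the frontend webapp and the Python requirements,
# then assemble the runtime directory (steps 1-6 of assembly/bin/supersonic-build.sh).
sh assembly/bin/supersonic-build.sh

# Once built, services are managed through the companion daemon script.
sh assembly/bin/supersonic-daemon.sh restart
```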
assembly/bin/supersonic-common.sh: 106 changed lines (new executable file)

@@ -0,0 +1,106 @@
#!/usr/bin/env bash

# environment parameters
python_path=${PYTHON_PATH:-"python3"}
pip_path=${PIP_PATH:-"pip3"}

sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build

readonly CHAT_APP_NAME="supersonic_chat"
readonly SEMANTIC_APP_NAME="supersonic_semantic"
readonly LLMPARSER_APP_NAME="supersonic_llmparser"
readonly STANDALONE_APP_NAME="supersonic_standalone"
readonly CHAT_SERVICE="chat"
readonly SEMANTIC_SERVICE="semantic"
readonly LLMPARSER_SERVICE="llmparser"
readonly STANDALONE_SERVICE="standalone"
readonly LLMPARSER_HOST="127.0.0.1"
readonly LLMPARSER_PORT="9092"

function setEnvToWeb {
model_name=$1
json='{"env": "'$model_name'"}'
echo $json > ${runtimeDir}/supersonic-${model_name}/webapp/supersonic.config.json
echo $json > $baseDir/../launchers/${model_name}/target/classes/webapp/supersonic.config.json
}

function moveToRuntime {
model_name=$1
tar -zxvf ${buildDir}/supersonic-${model_name}.tar.gz -C ${runtimeDir}
mv ${runtimeDir}/launchers-${model_name}-* ${runtimeDir}/supersonic-${model_name}

mkdir -p ${runtimeDir}/supersonic-${model_name}/webapp
cp -fr ${buildDir}/webapp/* ${runtimeDir}/supersonic-${model_name}/webapp
}

function moveAllToRuntime {
mkdir -p ${runtimeDir}
tar xvf ${buildDir}/supersonic-webapp.tar.gz -C ${buildDir}
mv ${buildDir}/supersonic-webapp ${buildDir}/webapp

moveToRuntime chat
moveToRuntime semantic
moveToRuntime standalone
rm -fr ${buildDir}/webapp
}

# run java service
function runJavaService {
javaRunDir=${runtimeDir}/supersonic-${model_name}
local_app_name=$1
libDir=$javaRunDir/lib
confDir=$javaRunDir/conf

CLASSPATH=""
CLASSPATH=$CLASSPATH:$confDir

for jarPath in $libDir/*.jar; do
CLASSPATH=$CLASSPATH:$jarPath
done

export CLASSPATH
export LANG="zh_CN.UTF-8"

cd $javaRunDir
if [[ "$JAVA_HOME" == "" ]]; then
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
fi
export PATH=$JAVA_HOME/bin:$PATH
command="-Dfile.encoding="UTF-8" -Duser.language="Zh" -Duser.region="CN" -Duser.timezone="GMT+08" -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m "$main_class

mkdir -p $javaRunDir/logs
if [[ "$is_test" == "true" ]]; then
java -Dspring.profiles.active="dev" $command >/dev/null 2>$javaRunDir/logs/error.log &
else
java $command $javaRunDir >/dev/null 2>$javaRunDir/logs/error.log &
fi
}

# run python service
function runPythonService {
pythonRunDir=${runtimeDir}/supersonic-${model_name}/llmparser
cd $pythonRunDir
nohup ${python_path} supersonic_llmparser.py > $pythonRunDir/llmparser.log 2>&1 &
# add health check
for i in {1..10}
do
echo "llmparser health check attempt $i..."
response=$(curl -s http://${LLMPARSER_HOST}:${LLMPARSER_PORT}/health)
echo "llmparser health check response: $response"
status_ok="Healthy"
if [[ $response == *$status_ok* ]] ; then
echo "llmparser Health check passed."
break
else
if [ "$i" -eq 10 ]; then
echo "llmparser Health check failed after 10 attempts."
echo "May still downloading model files. Please check llmparser.log in runtime directory."
fi
echo "Retrying after 5 seconds..."
sleep 5
fi
done
}
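The helper above resolves its interpreters from the PYTHON_PATH and PIP_PATH environment variables and polls the llmparser health endpoint during startup. A small sketch using only values defined in the script (the interpreter paths are illustrative):

```bash
# Override the Python/pip binaries the build and daemon scripts will use
export PYTHON_PATH=/usr/local/bin/python3
export PIP_PATH=/usr/local/bin/pip3

# The same probe runPythonService retries up to 10 times, 5 seconds apart;
# a healthy llmparser responds with a body containing "Healthy".
curl -s http://127.0.0.1:9092/health
```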
@@ -1,120 +1,118 @@
@echo off
setlocal

chcp 65001
set "sbinDir=%~dp0"
set "baseDir=%~dp0.."
set "runtimeDir=%baseDir%\..\runtime"
set "buildDir=%baseDir%\build"
set "main_class=com.tencent.supersonic.StandaloneLauncher"
set "python_path=python"
set "pip_path=pip3"
set "standalone_service=standalone"
set "llmparser_service=llmparser"

set "javaRunDir=%runtimeDir%\supersonic-standalone"
set "pythonRunDir=%runtimeDir%\supersonic-standalone\llmparser"

set "command=%~1"
set "module=%~2"
set "service=%~2"

set "APP_NAME=standalone-service"
set "MAIN_CLASS=com.tencent.supersonic.StandaloneLauncher"

set "python_path=python"
set "pip_path=pip3.9"
set "llm_host=127.0.0.1"
set "llm_port=9092"
set "start_name=api_service"

set "llm_path=%baseDir%\..\chat\core\src\main\python"

if "%module%"=="" (
set "module=standalone"
) else if "%module%"=="semantic" (
set "APP_NAME=semantic-service"
set "MAIN_CLASS=com.tencent.supersonic.SemanticLauncher"
) else if "%module%"=="chat" (
set "APP_NAME=chat-service"
set "MAIN_CLASS=com.tencent.supersonic.ChatLauncher"
if "%service%"=="" (
set "service=%standalone_service%"
)

if "%command%"=="" (
set "command=restart"
)

set "libDir=%runtimeDir%\supersonic-%module%\lib"
set "confDir=%runtimeDir%\supersonic-%module%\conf"
set "webDir=%runtimeDir%\supersonic-%module%\webapp"
set "CLASSPATH=%confDir%;%webDir%;%libDir%\*"
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"

if "%command%"=="stop" (
call:STOP
goto :EOF
)
call :BUILD_RUNTIME

if "%command%"=="restart" (
call:STOP
)

::1. clear file
rd /s /q "%runtimeDir%"
mkdir "%runtimeDir%"

if "%module%"=="llmparser" (
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
move "%%f" "%runtimeDir%\supersonic-standalone"
)
cd "%runtimeDir%"
"%pip_path%" install -r "%llm_path%\requirements.txt"
"%python_path%" -c "import langchain,fastapi,chromadb,tiktoken,uvicorn" >nul 2>&1
cd "%runtimeDir%\supersonic-standalone\llm\llm"
start "" /B uvicorn %start_name%:app --port %llm_port% --host %llm_host% > "%runtimeDir%\supersonic-standalone\llm\llm.log" 2>&1
echo "llm service started, see logs/error with logs/error command"
call :STOP
call :START
goto :EOF
) else if "%command%"=="start" (
call :START
goto :EOF
) else if "%command%"=="stop" (
call :STOP
goto :EOF
) else if "%command%"=="reload" (
call :RELOAD_EXAMPLE
goto :EOF
) else (
echo "Use command {start|stop|restart} to run."
goto :EOF
)

tar -zxvf "%buildDir%\supersonic-%module%.tar.gz" -C "%runtimeDir%"
for /d %%f in ("%runtimeDir%\launchers-%module%-*") do (
move "%%f" "%runtimeDir%\supersonic-%module%"
)

if not exist "%runtimeDir%\supersonic-%module%\logs" mkdir "%runtimeDir%\supersonic-%module%\logs"

tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
if not exist "%runtimeDir%\supersonic-%module%\webapp" mkdir "%runtimeDir%\supersonic-%module%\webapp"
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-%module%\webapp"
if not exist "%runtimeDir%\supersonic-%module%\conf\webapp" mkdir "%runtimeDir%\supersonic-%module%\conf\webapp"
xcopy /s /e /h /y "%runtimeDir%\supersonic-%module%\webapp\*" "%runtimeDir%\supersonic-%module%\conf\webapp"
rd /s /q "%buildDir%\supersonic-webapp"

::3. start service
::start standalone service
if "%command%"=="start" (
call:START
goto :EOF
)

if "%command%"=="restart" (
call:START
goto :EOF
)

:START
if "%module%"=="standalone" (
cd "%runtimeDir%"
"%pip_path%" install -r "%llm_path%\requirements.txt"
"%python_path%" -c "import langchain,fastapi,chromadb,tiktoken,uvicorn" >nul 2>&1
cd "%runtimeDir%\supersonic-standalone\llm\llm"
start "" /B uvicorn %start_name%:app --port %llm_port% --host %llm_host% > "%runtimeDir%\supersonic-standalone\llm\llm.log" 2>&1
echo "llm service started, see logs/error with logs/error command"
if "%service%"=="%llmparser_service%" (
call :START_PYTHON
goto :EOF
)
start "supersonic" /B java %java-command%>"%runtimeDir%\supersonic-%module%\logs\info-%module%.log" 2>&1
echo "%module% service started, see logs/error with logs/error command"
call :START_PYTHON
call :START_JAVA
goto :EOF

:STOP
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
taskkill /PID %%i /F
echo "llm Process (PID = %%i) is killed."
if "%service%"=="%llmparser_service%" (
call :STOP_PYTHON
goto :EOF
)
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
taskkill /PID %%i /F
echo "%module% Process (PID = %%i) is killed."
)
goto :EOF
call :STOP_PYTHON
call :STOP_JAVA
goto :EOF

:START_PYTHON
echo 'python service starting, see logs in llmparser/llmparser.log'
cd "%pythonRunDir%"
start /B %python_path% supersonic_llmparser.py > %pythonRunDir%\llmparser.log 2>&1
timeout /t 10 >nul
echo 'python service started'
goto :EOF

:START_JAVA
echo 'java service starting, see logs in logs/'
cd "%javaRunDir%"
if not exist "%runtimeDir%\supersonic-standalone\logs" mkdir "%runtimeDir%\supersonic-standalone\logs"
set "libDir=%runtimeDir%\supersonic-%service%\lib"
set "confDir=%runtimeDir%\supersonic-%service%\conf"
set "webDir=%runtimeDir%\supersonic-%service%\webapp"
set "classpath=%confDir%;%webDir%;%libDir%\*"
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
start /B java %java-command% >nul 2>&1
timeout /t 10 >nul
echo 'java service started'
goto :EOF

:STOP_PYTHON
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
taskkill /PID %%i /F
echo "python service (PID = %%i) is killed."
)
goto :EOF

:STOP_JAVA
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
taskkill /PID %%i /F
echo "java service (PID = %%i) is killed."
)
goto :EOF

:RELOAD_EXAMPLE
cd "%runtimeDir%\supersonic-standalone\llmparser\sql"
start %python_path% examples_reload_run.py
goto :EOF

:BUILD_RUNTIME
rem 6. reset runtime
if exist "%runtimeDir%" goto :EOF
mkdir "%runtimeDir%"
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
move "%%f" "%runtimeDir%\supersonic-standalone"
)

rem 7. copy webapp to runtime
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
rd /s /q "%buildDir%\supersonic-webapp"
assembly/bin/supersonic-daemon.sh: 193 changed lines (Normal file → Executable file)

@@ -1,98 +1,147 @@
#!/usr/bin/env bash

set -x
sbinDir=$(cd "$(dirname "$0")"; pwd)
baseDir=$(cd "$sbinDir/.." && pwd -P)
runtimeDir=$baseDir/../runtime
buildDir=$baseDir/build
chmod +x $sbinDir/supersonic-common.sh
source $sbinDir/supersonic-common.sh

# 1.init environment parameters
if [ ! -d "$runtimeDir" ]; then
echo "the runtime dir does not exist move all to runtime"
moveAllToRuntime
fi
set +x

command=$1
service=$2
if [ -z "$service" ]; then
service=${STANDALONE_SERVICE}
fi

app_name=$STANDALONE_APP_NAME
main_class="com.tencent.supersonic.StandaloneLauncher"
model_name=$service

if [ "$service" == "llmparser" ]; then
model_name=${STANDALONE_SERVICE}
fi

cd $baseDir
if [[ "$service" == "semantic" || -z "$service" ]] && [ "$command" != "stop" ]; then
#1. clear file
mkdir -p ${runtimeDir}
rm -fr ${runtimeDir}/*

#2. package lib
tar -zxvf ${buildDir}/supersonic.tar.gz -C ${runtimeDir}
mv ${runtimeDir}/launchers-standalone-* ${runtimeDir}/supersonic-standalone
tar -zxvf ${buildDir}/supersonic-webapp.tar.gz -C ${buildDir}
mkdir -p ${runtimeDir}/supersonic-standalone/webapp
cp -fr ${buildDir}/supersonic-webapp/* ${runtimeDir}/supersonic-standalone/webapp
rm -fr ${buildDir}/supersonic-webapp
fi
if [[ "$service" == "semantic" ]]; then
json=$(cat ${runtimeDir}/supersonic-semantic/webapp/supersonic.config.json)
json=$(echo $json | jq '.env="semantic"')
echo $json > ${runtimeDir}/supersonic-semantic/webapp/supersonic.config.json
fi
# 2.set main class
function setMainClass {
if [ "$service" == $CHAT_SERVICE ]; then
main_class="com.tencent.supersonic.ChatLauncher"
elif [ "$service" == $SEMANTIC_SERVICE ]; then
main_class="com.tencent.supersonic.SemanticLauncher"
fi
}
setMainClass
# 3.set app name
function setAppName {
if [ "$service" == $CHAT_SERVICE ]; then
app_name=$CHAT_APP_NAME
elif [ "$service" == $SEMANTIC_SERVICE ]; then
app_name=$SEMANTIC_APP_NAME
elif [ "$service" == $LLMPARSER_SERVICE ]; then
app_name=$LLMPARSER_APP_NAME
fi
}
setAppName

if [[ "$service" == "chat" ]]; then
json=$(cat ${runtimeDir}/supersonic-chat/webapp/supersonic.config.json)
json=$(echo $json | jq '.env="chat"')
echo $json > ${runtimeDir}/supersonic-chat/webapp/supersonic.config.json
fi
echo $command
echo $service
function reloadExamples {
pythonRunDir=${runtimeDir}/supersonic-${model_name}/llmparser
cd $pythonRunDir/sql
${python_path} examples_reload_run.py
}

function start()
{
local_app_name=$1
pid=$(ps aux |grep ${local_app_name} | grep -v grep | awk '{print $2}')
if [[ "$pid" == "" ]]; then
if [[ ${local_app_name} == $LLMPARSER_APP_NAME ]]; then
runPythonService ${local_app_name}
else
runJavaService ${local_app_name}
fi
else
echo "Process (PID = $pid) is running."
return 1
fi
}

function stop()
{
pid=$(ps aux | grep $1 | grep -v grep | awk '{print $2}')
if [[ "$pid" == "" ]]; then
echo "Process $1 is not running !"
return 1
else
kill -9 $pid
echo "Process (PID = $pid) is killed !"
return 0
fi
}

function reload()
{
if [[ $1 == $LLMPARSER_APP_NAME ]]; then
reloadExamples
fi
}

# 4. execute command operation
case "$command" in
start)
if [[ "$service" == "semantic" ]];then
echo -e "Starting semantic"
sh ${runtimeDir}/supersonic-semantic/bin/service.sh start
elif [[ "$service" == "chat" ]];then
echo -e "Starting chat"
sh ${runtimeDir}/supersonic-chat/bin/service.sh start
elif [[ "$service" == "llmparser" ]];then
echo -e "Starting LLM"
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh start
elif [[ -z "$service" ]]; then
echo -e "Starting supersonic"
sh ${runtimeDir}/supersonic-standalone/bin/service.sh start
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh start
if [ "$service" == $STANDALONE_SERVICE ]; then
echo "Starting $LLMPARSER_APP_NAME"
start $LLMPARSER_APP_NAME
echo "Starting $app_name"
start $app_name
else
echo "Use command {semantic|semantic||} to run."
echo "Starting $app_name"
start $app_name
fi
echo "Start success"
;;
stop)
if [[ "$service" == "semantic" ]];then
echo -e "Stopping semantic"
sh ${runtimeDir}/supersonic-semantic/bin/service.sh stop
elif [[ "$service" == "chat" ]];then
echo -e "Stopping chat"
sh ${runtimeDir}/supersonic-chat/bin/service.sh stop
elif [[ "$service" == "llmparser" ]];then
echo -e "Stopping LLM"
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh stop
elif [[ -z "$service" ]]; then
echo -e "Stopping supersonic"
sh ${runtimeDir}/supersonic-standalone/bin/service.sh stop
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh stop
if [ "$service" == $STANDALONE_SERVICE ]; then
echo "Stopping $LLMPARSER_APP_NAME"
stop $LLMPARSER_APP_NAME
echo "Stopping $app_name"
stop $app_name
else
echo "Use command {semantic|semantic||} to run."
echo "Stopping $app_name"
stop ${app_name}
fi
echo "Stop success"
;;
reload)
echo "Reloading ${app_name}"
reload ${app_name}
echo "Reload success"
;;
restart)
if [[ "$service" == "semantic" ]];then
echo -e "Restarting semantic"
sh ${runtimeDir}/supersonic-semantic/bin/service.sh restart
elif [[ "$service" == "chat" ]];then
echo -e "Restarting chat"
sh ${runtimeDir}/supersonic-chat/bin/service.sh restart
elif [[ "$service" == "llmparser" ]];then
echo -e "Restarting LLM"
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart
elif [[ -z "$service" ]]; then
echo -e "Restarting supersonic"
sh ${runtimeDir}/supersonic-standalone/bin/service.sh restart
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart
if [ "$service" == $STANDALONE_SERVICE ]; then
echo "Stopping ${app_name}"
stop ${app_name}
echo "Stopping ${LLMPARSER_APP_NAME}"
stop $LLMPARSER_APP_NAME
echo "Starting ${LLMPARSER_APP_NAME}"
start $LLMPARSER_APP_NAME
echo "Starting ${app_name}"
start ${app_name}
else
echo "Use command {semantic|semantic||} to run."
echo "Stopping ${app_name}"
stop ${app_name}
echo "Starting ${app_name}"
start ${app_name}
fi
echo "Restart success"
;;
*)
echo "Use command {start|stop|status|restart} to run."
echo "Use command {start|stop|restart} to run."
exit 1
esac

exit 0
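Taken together, the rewritten daemon script exposes one command interface for every service. A usage sketch based on the case branches above (the service argument defaults to standalone when omitted):

```bash
# Start the standalone Java service together with the llmparser Python service
sh assembly/bin/supersonic-daemon.sh start

# Manage a single service: standalone | chat | semantic | llmparser
sh assembly/bin/supersonic-daemon.sh restart chat
sh assembly/bin/supersonic-daemon.sh stop llmparser

# Re-run the llmparser example loader (examples_reload_run.py)
sh assembly/bin/supersonic-daemon.sh reload llmparser
```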
@@ -6,14 +6,6 @@
<format>tar.gz</format>
</formats>
<fileSets>

<fileSet>
<directory>${project.basedir}/src/main/bin</directory>
<outputDirectory>bin</outputDirectory>
<fileMode>0777</fileMode>
<directoryMode>0755</directoryMode>
</fileSet>

<fileSet>
<directory>${project.basedir}/src/main/resources</directory>
<outputDirectory>conf</outputDirectory>
@@ -29,8 +21,8 @@
</includes>
</fileSet>
<fileSet>
<directory>${project.basedir}/../../chat/core/src/main/python</directory>
<outputDirectory>llm</outputDirectory>
<directory>${project.basedir}/../../chat/python</directory>
<outputDirectory>llmparser</outputDirectory>
<fileMode>0777</fileMode>
<directoryMode>0755</directoryMode>
</fileSet>
@@ -12,6 +12,8 @@ public class UserConstants {

public static final String TOKEN_USER_EMAIL = "token_user_email";

public static final String TOKEN_IS_ADMIN = "token_is_admin";

public static final String TOKEN_ALGORITHM = "HS512";

public static final String TOKEN_CREATE_TIME = "token_create_time";
@@ -18,17 +18,22 @@ public class User {

private String email;

public static User get(Long id, String name, String displayName, String email) {
return new User(id, name, displayName, email);
private Integer isAdmin;

public static User get(Long id, String name, String displayName, String email, Integer isAdmin) {
return new User(id, name, displayName, email, isAdmin);
}

public static User getFakeUser() {
return new User(1L, "admin", "admin", "admin@email");
return new User(1L, "admin", "admin", "admin@email", 1);
}

public String getDisplayName() {
return StringUtils.isBlank(displayName) ? name : displayName;
}

public boolean isSuperAdmin() {
return isAdmin != null && isAdmin == 1;
}

}
@@ -9,13 +9,14 @@ public class UserWithPassword extends User {

private String password;

public UserWithPassword(Long id, String name, String displayName, String email, String password) {
super(id, name, displayName, email);
public UserWithPassword(Long id, String name, String displayName, String email, String password, Integer isAdmin) {
super(id, name, displayName, email, isAdmin);
this.password = password;
}

public static UserWithPassword get(Long id, String name, String displayName, String email, String password) {
return new UserWithPassword(id, name, displayName, email, password);
public static UserWithPassword get(Long id, String name, String displayName,
String email, String password, Integer isAdmin) {
return new UserWithPassword(id, name, displayName, email, password, isAdmin);
}

}
@@ -6,7 +6,7 @@ import lombok.Data;
@Data
public class AuthGroup {

private String modelId;
private Long modelId;
private String name;
private Integer groupId;
private List<AuthRule> authRules;
@@ -7,13 +7,13 @@ import lombok.ToString;
@ToString
public class AuthRes {

private String modelId;
private Long modelId;
private String name;

public AuthRes() {
}

public AuthRes(String modelId, String name) {
public AuthRes(Long modelId, String name) {
this.modelId = modelId;
this.name = name;
}
@@ -1,11 +1,13 @@
package com.tencent.supersonic.auth.api.authorization.request;

import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
import java.util.ArrayList;
import java.util.List;

import lombok.Data;
import lombok.ToString;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.List;

@Data
@ToString
@@ -15,5 +17,17 @@ public class QueryAuthResReq {

private List<AuthRes> resources;

private String modelId;
private Long modelId;

private List<Long> modelIds;

public List<Long> getModelIds() {
if (!CollectionUtils.isEmpty(modelIds)) {
return modelIds;
}
if (modelId != null) {
return Lists.newArrayList(modelId);
}
return Lists.newArrayList();
}
}
@@ -10,7 +10,7 @@ public interface AuthService {

List<AuthGroup> queryAuthGroups(String domainId, Integer groupId);

void updateAuthGroup(AuthGroup group);
void addOrUpdateAuthGroup(AuthGroup group);

void removeAuthGroup(AuthGroup group);
@@ -33,12 +33,6 @@
<artifactId>spring-boot-starter-jdbc</artifactId>
</dependency>

<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
</dependency>

<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
@@ -52,12 +46,7 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis-spring</artifactId>
<version>${mybatis-spring.version}</version>
<scope>compile</scope>
</dependency>

<dependency>
<groupId>com.github.pagehelper</groupId>
<artifactId>pagehelper</artifactId>
@@ -71,7 +71,7 @@ public class DefaultUserAdaptor implements UserAdaptor {
}
if (userDO.getPassword().equals(userReq.getPassword())) {
UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
userDO.getEmail(), userDO.getPassword());
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
return userTokenUtils.generateToken(user);
}
throw new RuntimeException("password not correct, please try again");
@@ -1,99 +1,129 @@
package com.tencent.supersonic.auth.authentication.persistence.dataobject;

public class UserDO {

/**
 *
 *
 */
private Long id;

/**
 *
 *
 */
private String name;

/**
 *
 *
 */
private String password;

/**
 *
 *
 */
private String displayName;

/**
 *
 *
 */
private String email;

/**
 * @return id
 *
 */
private Integer isAdmin;

/**
 *
 * @return id
 */
public Long getId() {
return id;
}

/**
 * @param id
 *
 * @param id
 */
public void setId(Long id) {
this.id = id;
}

/**
 * @return name
 *
 * @return name
 */
public String getName() {
return name;
}

/**
 * @param name
 *
 * @param name
 */
public void setName(String name) {
this.name = name == null ? null : name.trim();
}

/**
 * @return password
 *
 * @return password
 */
public String getPassword() {
return password;
}

/**
 * @param password
 *
 * @param password
 */
public void setPassword(String password) {
this.password = password == null ? null : password.trim();
}

/**
 * @return display_name
 *
 * @return display_name
 */
public String getDisplayName() {
return displayName;
}

/**
 * @param displayName
 *
 * @param displayName
 */
public void setDisplayName(String displayName) {
this.displayName = displayName == null ? null : displayName.trim();
}

/**
 * @return email
 *
 * @return email
 */
public String getEmail() {
return email;
}

/**
 * @param email
 *
 * @param email
 */
public void setEmail(String email) {
this.email = email == null ? null : email.trim();
}

/**
 *
 * @return is_admin
 */
public Integer getIsAdmin() {
return isAdmin;
}

/**
 *
 * @param isAdmin
 */
public void setIsAdmin(Integer isAdmin) {
this.isAdmin = isAdmin;
}
}
@@ -4,7 +4,6 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class UserDOExample {
|
||||
|
||||
/**
|
||||
* s2_user
|
||||
*/
|
||||
@@ -31,6 +30,7 @@ public class UserDOExample {
|
||||
protected Integer limitEnd;
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public UserDOExample() {
|
||||
@@ -38,13 +38,7 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
* @mbg.generated
|
||||
*/
|
||||
public String getOrderByClause() {
|
||||
return orderByClause;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public void setOrderByClause(String orderByClause) {
|
||||
@@ -52,13 +46,15 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public boolean isDistinct() {
|
||||
return distinct;
|
||||
public String getOrderByClause() {
|
||||
return orderByClause;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public void setDistinct(boolean distinct) {
|
||||
@@ -66,6 +62,15 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public boolean isDistinct() {
|
||||
return distinct;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public List<Criteria> getOredCriteria() {
|
||||
@@ -73,6 +78,7 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public void or(Criteria criteria) {
|
||||
@@ -80,6 +86,7 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public Criteria or() {
|
||||
@@ -89,6 +96,7 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public Criteria createCriteria() {
|
||||
@@ -100,6 +108,7 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
protected Criteria createCriteriaInternal() {
|
||||
@@ -108,6 +117,7 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public void clear() {
|
||||
@@ -117,6 +127,15 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public void setLimitStart(Integer limitStart) {
|
||||
this.limitStart=limitStart;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public Integer getLimitStart() {
|
||||
@@ -124,31 +143,25 @@ public class UserDOExample {
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public void setLimitStart(Integer limitStart) {
|
||||
this.limitStart = limitStart;
|
||||
public void setLimitEnd(Integer limitEnd) {
|
||||
this.limitEnd=limitEnd;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
public Integer getLimitEnd() {
return limitEnd;
}

/**
* @mbg.generated
*/
public void setLimitEnd(Integer limitEnd) {
this.limitEnd = limitEnd;
}

/**
* s2_user null
*/
protected abstract static class GeneratedCriteria {

protected List<Criterion> criteria;

protected GeneratedCriteria() {

@@ -528,6 +541,66 @@ public class UserDOExample {
addCriterion("email not between", value1, value2, "email");
return (Criteria) this;
}

public Criteria andIsAdminIsNull() {
addCriterion("is_admin is null");
return (Criteria) this;
}

public Criteria andIsAdminIsNotNull() {
addCriterion("is_admin is not null");
return (Criteria) this;
}

public Criteria andIsAdminEqualTo(Integer value) {
addCriterion("is_admin =", value, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminNotEqualTo(Integer value) {
addCriterion("is_admin <>", value, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminGreaterThan(Integer value) {
addCriterion("is_admin >", value, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminGreaterThanOrEqualTo(Integer value) {
addCriterion("is_admin >=", value, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminLessThan(Integer value) {
addCriterion("is_admin <", value, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminLessThanOrEqualTo(Integer value) {
addCriterion("is_admin <=", value, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminIn(List<Integer> values) {
addCriterion("is_admin in", values, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminNotIn(List<Integer> values) {
addCriterion("is_admin not in", values, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminBetween(Integer value1, Integer value2) {
addCriterion("is_admin between", value1, value2, "isAdmin");
return (Criteria) this;
}

public Criteria andIsAdminNotBetween(Integer value1, Integer value2) {
addCriterion("is_admin not between", value1, value2, "isAdmin");
return (Criteria) this;
}
}

/**
@@ -544,7 +617,6 @@ public class UserDOExample {
* s2_user null
*/
public static class Criterion {

private String condition;

private Object value;

@@ -561,6 +633,38 @@ public class UserDOExample {

private String typeHandler;

public String getCondition() {
return condition;
}

public Object getValue() {
return value;
}

public Object getSecondValue() {
return secondValue;
}

public boolean isNoValue() {
return noValue;
}

public boolean isSingleValue() {
return singleValue;
}

public boolean isBetweenValue() {
return betweenValue;
}

public boolean isListValue() {
return listValue;
}

public String getTypeHandler() {
return typeHandler;
}

protected Criterion(String condition) {
super();
this.condition = condition;

@@ -596,37 +700,5 @@ public class UserDOExample {
protected Criterion(String condition, Object value, Object secondValue) {
this(condition, value, secondValue, null);
}

public String getCondition() {
return condition;
}

public Object getValue() {
return value;
}

public Object getSecondValue() {
return secondValue;
}

public boolean isNoValue() {
return noValue;
}

public boolean isSingleValue() {
return singleValue;
}

public boolean isBetweenValue() {
return betweenValue;
}

public boolean isListValue() {
return listValue;
}

public String getTypeHandler() {
return typeHandler;
}
}
}
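The generated example class above is normally consumed through the matching `UserDOMapper`. A minimal usage sketch, assuming the standard MyBatis-generator API (`createCriteria()`, `setLimitStart` and `selectByExample` are not shown in this diff and are assumed to exist as generated):

```java
// Hypothetical usage: filter on the new is_admin column with paging.
// Assumes userDOMapper is an injected mapper generated alongside UserDOExample.
UserDOExample example = new UserDOExample();
example.createCriteria().andIsAdminEqualTo(1);   // only admin users
example.setLimitStart(0);                        // assumed counterpart of setLimitEnd above
example.setLimitEnd(10);
List<UserDO> admins = userDOMapper.selectByExample(example);
```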
@@ -2,6 +2,7 @@ package com.tencent.supersonic.auth.authentication.utils;

import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_ALGORITHM;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_PREFIX;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_TIME_OUT;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME;

@@ -42,6 +43,7 @@ public class UserTokenUtils {
claims.put(TOKEN_USER_PASSWORD, StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword());
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
claims.put(TOKEN_IS_ADMIN, user.getIsAdmin());
return generate(claims);
}

@@ -52,6 +54,7 @@ public class UserTokenUtils {
claims.put(TOKEN_USER_PASSWORD, "admin");
claims.put(TOKEN_USER_DISPLAY_NAME, "admin");
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
claims.put(TOKEN_IS_ADMIN, 1);
return generate(claims);
}

@@ -63,7 +66,9 @@ public class UserTokenUtils {
String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
return User.get(userId, userName, displayName, email);
Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null
? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
return User.get(userId, userName, displayName, email, isAdmin);
}

public UserWithPassword getUserWithPassword(HttpServletRequest request) {

@@ -79,7 +84,9 @@ public class UserTokenUtils {
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD));
return UserWithPassword.get(userId, userName, displayName, email, password);
Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null
? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin);
}

private Claims getClaims(String token) {
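Both token-parsing paths above apply the same null-safe default, so tokens issued before the `is_admin` claim existed still decode as non-admin. A small sketch of that pattern, extracted only for illustration (the helper name is hypothetical, not part of the diff):

```java
// Hypothetical helper mirroring the null check in getUser/getUserWithPassword:
// a missing TOKEN_IS_ADMIN claim (older tokens) degrades to 0, i.e. non-admin.
private Integer resolveIsAdmin(Claims claims) {
    Object rawIsAdmin = claims.get(TOKEN_IS_ADMIN);
    return rawIsAdmin == null ? 0 : Integer.parseInt(rawIsAdmin.toString());
}
```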
@@ -2,11 +2,12 @@
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper">
<resultMap id="BaseResultMap" type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
<id column="id" jdbcType="BIGINT" property="id" />
<result column="id" jdbcType="BIGINT" property="id" />
<result column="name" jdbcType="VARCHAR" property="name" />
<result column="password" jdbcType="VARCHAR" property="password" />
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
<result column="email" jdbcType="VARCHAR" property="email" />
<result column="is_admin" jdbcType="INTEGER" property="isAdmin" />
</resultMap>
<sql id="Example_Where_Clause">
<where>
@@ -38,7 +39,7 @@
</where>
</sql>
<sql id="Base_Column_List">
id, name, password, display_name, email
id, name, password, display_name, email, is_admin
</sql>
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
select
@@ -57,21 +58,13 @@
limit #{limitStart} , #{limitEnd}
</if>
</select>
<select id="selectByPrimaryKey" parameterType="java.lang.Long" resultMap="BaseResultMap">
select
<include refid="Base_Column_List" />
from s2_user
where id = #{id,jdbcType=BIGINT}
</select>
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
delete from s2_user
where id = #{id,jdbcType=BIGINT}
</delete>
<insert id="insert" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
insert into s2_user (id, name, password,
display_name, email)
display_name, email, is_admin
)
values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR},
#{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR})
#{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}, #{isAdmin,jdbcType=INTEGER}
)
</insert>
<insert id="insertSelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
insert into s2_user
@@ -91,6 +84,9 @@
<if test="email != null">
email,
</if>
<if test="isAdmin != null">
is_admin,
</if>
</trim>
<trim prefix="values (" suffix=")" suffixOverrides=",">
<if test="id != null">
@@ -108,6 +104,9 @@
<if test="email != null">
#{email,jdbcType=VARCHAR},
</if>
<if test="isAdmin != null">
#{isAdmin,jdbcType=INTEGER},
</if>
</trim>
</insert>
<select id="countByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultType="java.lang.Long">
@@ -116,30 +115,4 @@
<include refid="Example_Where_Clause" />
</if>
</select>
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
update s2_user
<set>
<if test="name != null">
name = #{name,jdbcType=VARCHAR},
</if>
<if test="password != null">
password = #{password,jdbcType=VARCHAR},
</if>
<if test="displayName != null">
display_name = #{displayName,jdbcType=VARCHAR},
</if>
<if test="email != null">
email = #{email,jdbcType=VARCHAR},
</if>
</set>
where id = #{id,jdbcType=BIGINT}
</update>
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
update s2_user
set name = #{name,jdbcType=VARCHAR},
password = #{password,jdbcType=VARCHAR},
display_name = #{displayName,jdbcType=VARCHAR},
email = #{email,jdbcType=VARCHAR}
where id = #{id,jdbcType=BIGINT}
</update>
</mapper>

@@ -13,7 +13,6 @@ import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -48,12 +47,12 @@ public class AuthServiceImpl implements AuthService {
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
return load().stream()
.filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
&& modelId.equals(group.getModelId()))
&& modelId.equals(group.getModelId().toString()))
.collect(Collectors.toList());
}

@Override
public void updateAuthGroup(AuthGroup group) {
public void addOrUpdateAuthGroup(AuthGroup group) {
Gson g = new Gson();
if (group.getGroupId() == null) {
int nextGroupId = 1;
@@ -80,17 +79,14 @@ public class AuthServiceImpl implements AuthService {
@Override
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
if (!CollectionUtils.isEmpty(userOrgIds)) {
req.setDepartmentIds(new ArrayList<>(userOrgIds));
}
List<AuthGroup> groups = getAuthGroups(req, user.getName());
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
AuthorizedResourceResp resource = new AuthorizedResourceResp();
Map<String, List<AuthGroup>> authGroupsByModelId = groups.stream()
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
.collect(Collectors.groupingBy(AuthGroup::getModelId));
Map<String, List<AuthRes>> reqAuthRes = req.getResources().stream()
Map<Long, List<AuthRes>> reqAuthRes = req.getResources().stream()
.collect(Collectors.groupingBy(AuthRes::getModelId));

for (String modelId : reqAuthRes.keySet()) {
for (Long modelId : reqAuthRes.keySet()) {
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
AuthResGrp rg = new AuthResGrp();
if (authGroupsByModelId.containsKey(modelId)) {
@@ -113,7 +109,7 @@ public class AuthServiceImpl implements AuthService {
}
}

if (StringUtils.isNotEmpty(req.getModelId())) {
if (req.getModelId() != null) {
List<AuthGroup> authGroups = authGroupsByModelId.get(req.getModelId());
if (!CollectionUtils.isEmpty(authGroups)) {
for (AuthGroup group : authGroups) {
@@ -130,17 +126,17 @@ public class AuthServiceImpl implements AuthService {
return resource;
}

private List<AuthGroup> getAuthGroups(QueryAuthResReq req, String userName) {
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
List<AuthGroup> groups = load().stream()
.filter(group -> {
if (!Objects.equals(group.getModelId(), req.getModelId())) {
if (CollectionUtils.isEmpty(modelIds) || !modelIds.contains(group.getModelId())) {
return false;
}
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
.contains(userName)) {
return true;
}
for (String departmentId : req.getDepartmentIds()) {
for (String departmentId : departmentIds) {
if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
&& group.getAuthorizedDepartmentIds().contains(departmentId)) {
return true;
@@ -148,7 +144,7 @@ public class AuthServiceImpl implements AuthService {
}
return false;
}).collect(Collectors.toList());
log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups);
log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups);
return groups;
}

@@ -40,7 +40,7 @@ public class AuthController {
@PostMapping("/createGroup")
public void newAuthGroup(@RequestBody AuthGroup group) {
group.setGroupId(null);
authService.updateAuthGroup(group);
authService.addOrUpdateAuthGroup(group);
}

@PostMapping("/removeGroup")
@@ -58,7 +58,7 @@ public class AuthController {
if (group.getGroupId() == null || group.getGroupId() == 0) {
throw new RuntimeException("groupId is empty");
}
authService.updateAuthGroup(group);
authService.addOrUpdateAuthGroup(group);
}

/**

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.api.component;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import net.sf.jsqlparser.JSQLParserException;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;

/**
* A semantic corrector checks validity of extracted semantic information and
@@ -9,5 +9,5 @@ import net.sf.jsqlparser.JSQLParserException;
*/
public interface SemanticCorrector {

void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException;
void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
}
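With the signature change above, a corrector now receives the original `QueryReq` plus the mutable `SemanticParseInfo` and no longer throws `JSQLParserException`. A minimal sketch of an implementation against the new contract (the class name and body are illustrative, not part of the diff):

```java
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;

// Illustrative no-op corrector; real correctors rewrite the S2SQL carried in
// semanticParseInfo.getSqlInfo() based on the incoming query.
public class NoOpCorrector implements SemanticCorrector {

    @Override
    public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
        // intentionally leaves the parse info unchanged
    }
}
```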
@@ -8,10 +8,14 @@ import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;

@@ -28,16 +32,32 @@ import java.util.List;
 * as proxy to a remote semantic service.
 * </p>
 */
public interface SemanticLayer {
public interface SemanticInterpreter {

QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);

QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);

QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);

QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);

List<ModelSchema> getModelSchema();

List<ModelSchema> getModelSchema(List<Long> ids);

ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd);

PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);

PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);

List<DomainResp> getDomainList(User user);

List<ModelResp> getModelList(AuthType authType, Long domainId, User user);

<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;

List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);

}

@@ -14,6 +14,10 @@ public interface SemanticQuery {

QueryResult execute(User user) throws SqlParseException;

void initS2Sql(User user);

String explain(User user);

SemanticParseInfo getParseInfo();

void setParseInfo(SemanticParseInfo parseInfo);

@@ -1,8 +1,13 @@
package com.tencent.supersonic.chat.api.pojo;

import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ModelRela;
import lombok.Data;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

@@ -13,7 +18,9 @@ public class ModelSchema {
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
private Set<SchemaElement> tags = new HashSet<>();
private SchemaElement entity = new SchemaElement();
private List<ModelRela> modelRelas = new ArrayList<>();

public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
@@ -34,6 +41,9 @@ public class ModelSchema {
case VALUE:
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case TAG:
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
break;
default:
}

@@ -44,4 +54,45 @@ public class ModelSchema {
}
}

public SchemaElement getElement(SchemaElementType elementType, String name) {
Optional<SchemaElement> element = Optional.empty();

switch (elementType) {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case MODEL:
element = Optional.of(model);
break;
case METRIC:
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
break;
case DIMENSION:
element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst();
break;
case VALUE:
element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst();
break;
default:
}

if (element.isPresent()) {
return element.get();
} else {
return null;
}
}

public Set<Long> getModelClusterSet() {
if (CollectionUtils.isEmpty(modelRelas)) {
return Sets.newHashSet();
}
Set<Long> modelClusterSet = new HashSet<>();
modelRelas.forEach(modelRela -> {
modelClusterSet.add(modelRela.getToModelId());
modelClusterSet.add(modelRela.getFromModelId());
});
return modelClusterSet;
}

}

@@ -15,6 +15,7 @@ public class QueryContext {
private QueryReq request;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();

public QueryContext(QueryReq request) {
this.request = request;

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.api.pojo;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class RelatedSchemaElement {

private Long dimensionId;

private boolean isNecessary;

}

@@ -1,30 +1,35 @@
package com.tencent.supersonic.chat.api.pojo;

import com.google.common.base.Objects;
import java.io.Serializable;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.List;

@Data
@Getter
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class SchemaElement implements Serializable {

private Long model;
private Long id;
private String name;
private String bizName;
private Long useCnt;
private SchemaElementType type;

private List<String> alias;

private List<SchemaValueMap> schemaValueMaps;
private List<RelatedSchemaElement> relatedSchemaElements;

private String defaultAgg;

private double order;

@Override
public boolean equals(Object o) {
@@ -37,13 +42,13 @@ public class SchemaElement implements Serializable {
SchemaElement schemaElement = (SchemaElement) o;
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
schemaElement.id) && Objects.equal(name, schemaElement.name)
&& Objects.equal(bizName, schemaElement.bizName) && Objects.equal(
useCnt, schemaElement.useCnt) && Objects.equal(type, schemaElement.type);
&& Objects.equal(bizName, schemaElement.bizName)
&& Objects.equal(type, schemaElement.type);
}

@Override
public int hashCode() {
return Objects.hashCode(model, id, name, bizName, useCnt, type);
return Objects.hashCode(model, id, name, bizName, type);
}

}

@@ -6,6 +6,7 @@ public enum SchemaElementType {
DIMENSION,
VALUE,
ENTITY,
TAG,
ID,
DATE
}

@@ -0,0 +1,61 @@
package com.tencent.supersonic.chat.api.pojo;

import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
import com.tencent.supersonic.common.pojo.ModelCluster;
import lombok.Data;
import org.springframework.util.CollectionUtils;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

@Data
public class SchemaModelClusterMapInfo {

private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();

public Set<String> getMatchedModelClusters() {
return modelElementMatches.keySet();
}

public List<SchemaElementMatch> getMatchedElements(Long modelId) {
for (String key : modelElementMatches.keySet()) {
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
return modelElementMatches.get(key);
}
}
return Lists.newArrayList();
}

public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
return modelElementMatches.get(modelCluster);
}

public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
return modelElementMatches;
}

public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
if (CollectionUtils.isEmpty(modelIds)) {
return modelElementMatches;
}
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
for (String key : modelElementMatches.keySet()) {
for (Long modelId : modelIds) {
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
}
}
}
return modelElementMatchesFiltered;
}

public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
this.modelElementMatches = modelElementMatches;
}

public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
modelElementMatches.put(modelCluster, elementMatches);
}
}
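A short usage sketch of the new map type (the cluster key `"1_2"` and the match list are made-up values; the actual key format is whatever `ModelCluster.getModelIdFromKey` understands):

```java
// Illustrative only: matches are stored per model-cluster key and can be read back
// either by the full key or by any model id contained in that cluster.
SchemaModelClusterMapInfo mapInfo = new SchemaModelClusterMapInfo();
List<SchemaElementMatch> matches = new ArrayList<>();               // normally produced by the schema mappers
mapInfo.setMatchedElements("1_2", matches);                         // "1_2" is an assumed key format
List<SchemaElementMatch> forModel2 = mapInfo.getMatchedElements(2L); // resolved via ModelCluster.getModelIdFromKey
```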
@@ -1,53 +1,74 @@
package com.tencent.supersonic.chat.api.pojo;

import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.LinkedHashSet;
import java.util.ArrayList;
import java.util.Map;
import java.util.HashMap;
import java.util.Comparator;

import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import lombok.Data;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

@Data
public class SemanticParseInfo {

private Integer id;
private String queryMode;
private SchemaElement model;
private ModelCluster model = new ModelCluster();
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private FilterType filterType = FilterType.UNION;
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
private Set<QueryFilter> metricFilters = new LinkedHashSet();
private Set<Order> orders = new LinkedHashSet();
private DateConf dateInfo;
private Long limit;
private Boolean nativeQuery = false;
private double score;
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
public Long getModelId() {
return model != null ? model.getId() : 0L;
private SqlInfo sqlInfo = new SqlInfo();
private QueryType queryType = QueryType.OTHER;

public String getModelClusterKey() {
if (model == null) {
return "";
}
return model.getKey();
}

public String getModelName() {
return model != null ? model.getName() : "null";
if (model == null) {
return "";
}
return model.getName();
}

private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {

@Override
public int compare(SchemaElement o1, SchemaElement o2) {
if (o1.getOrder() != o2.getOrder()) {
if (o1.getOrder() < o2.getOrder()) {
return -1;
} else {
return 1;
}
}
int len1 = o1.getName().length();
int len2 = o2.getName().length();
if (len1 != len2) {
@@ -65,4 +86,26 @@ public class SemanticParseInfo {
return metrics;
}

private Map<Long, Integer> getModelElementCountMap() {
Map<Long, Integer> elementCountMap = new HashMap<>();
elementMatches.forEach(element -> {
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
elementCountMap.put(element.getElement().getModel(), count + 1);
});
return elementCountMap;
}

public Long getModelId() {
Map<Long, Integer> elementCountMap = getModelElementCountMap();
Long modelId = -1L;
int maxCnt = 0;
for (Long model : elementCountMap.keySet()) {
if (elementCountMap.get(model) > maxCnt) {
maxCnt = elementCountMap.get(model);
modelId = model;
}
}
return modelId;
}

}

@@ -1,12 +1,18 @@
package com.tencent.supersonic.chat.api.pojo;

import org.springframework.util.CollectionUtils;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

public class SemanticSchema implements Serializable {

private List<ModelSchema> modelSchemaList;

public SemanticSchema(List<ModelSchema> modelSchemaList) {
@@ -17,6 +23,64 @@ public class SemanticSchema implements Serializable {
modelSchemaList.add(schema);
}

public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();

switch (elementType) {
case ENTITY:
element = getElementsById(elementID, getEntities());
break;
case MODEL:
element = getElementsById(elementID, getModels());
break;
case METRIC:
element = getElementsById(elementID, getMetrics());
break;
case DIMENSION:
element = getElementsById(elementID, getDimensions());
break;
case VALUE:
element = getElementsById(elementID, getDimensionValues());
break;
default:
}

if (element.isPresent()) {
return element.get();
} else {
return null;
}
}

public SchemaElement getElementByName(SchemaElementType elementType, String name) {
Optional<SchemaElement> element = Optional.empty();

switch (elementType) {
case ENTITY:
element = getElementsByName(name, getEntities());
break;
case MODEL:
element = getElementsByName(name, getModels());
break;
case METRIC:
element = getElementsByName(name, getMetrics());
break;
case DIMENSION:
element = getElementsByName(name, getDimensions());
break;
case VALUE:
element = getElementsByName(name, getDimensionValues());
break;
default:
}

if (element.isPresent()) {
return element.get();
} else {
return null;
}
}

public Map<Long, String> getModelIdToName() {
return modelSchemaList.stream()
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
@@ -34,21 +98,84 @@ public class SemanticSchema implements Serializable {
return dimensions;
}

public List<SchemaElement> getDimensions(Set<Long> modelIds) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByModelId(modelIds, dimensions);
}

public SchemaElement getDimensions(Long id) {
List<SchemaElement> dimensions = getDimensions();
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
return dimension.orElse(null);
}

public List<SchemaElement> getTags() {
List<SchemaElement> tags = new ArrayList<>();
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
return tags;
}

public List<SchemaElement> getTags(Set<Long> modelIds) {
List<SchemaElement> tags = new ArrayList<>();
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.forEach(d -> tags.addAll(d.getTags()));
return tags;
}

public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}

public List<SchemaElement> getMetrics(Set<Long> modelIds) {
List<SchemaElement> metrics = getMetrics();
return getElementsByModelId(modelIds, metrics);
}

public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
}

private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
.collect(Collectors.toList());
}

private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> id.equals(schemaElement.getId()))
.findFirst();
}

private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> name.equals(schemaElement.getName()))
.findFirst();
}

public List<SchemaElement> getModels() {
List<SchemaElement> models = new ArrayList<>();
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
return models;
}

public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(modelIds));
allElements.addAll(getMetrics(modelIds));
return allElements.stream()
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}

public Map<Long, ModelSchema> getModelSchemaMap() {
if (CollectionUtils.isEmpty(modelSchemaList)) {
return new HashMap<>();
}
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
}
}

@@ -32,6 +32,11 @@ public class ChatConfigBaseReq {
*/
private List<RecommendedQuestionReq> recommendedQuestions;

/**
* the llm examples about the model
*/
private String llmExamples;

/**
* available status
*/

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.api.pojo.request;

import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;

import javax.validation.constraints.NotNull;
import java.util.List;

import static java.time.LocalDate.now;

@ToString
@Data
@NoArgsConstructor
public class DictLatestTaskReq {

@NotNull
private Long modelId;

private List<Long> dimIds;

private String createdAt = now().plusDays(-4).toString();
}

@@ -0,0 +1,20 @@
package com.tencent.supersonic.chat.api.pojo.request;

import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
import lombok.Data;
import lombok.ToString;

@ToString
@Data
public class DictTaskFilterReq {

private Long id;

private String name;

private String createdBy;

private String createdAt;

private TaskStatusEnum status;
}

@@ -1,12 +1,21 @@
package com.tencent.supersonic.chat.api.pojo.request;

import javax.validation.constraints.NotNull;
import lombok.Data;

@Data
public class DimensionValueReq {

private Integer agentId;

@NotNull
private Long elementID;

@NotNull
private Long modelId;

private String bizName;

private Object value;
@NotNull
private String value;
}

@@ -3,16 +3,18 @@ package com.tencent.supersonic.chat.api.pojo.request;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Builder;
import lombok.Data;

@Builder
@Data
public class ExecuteQueryReq {
private User user;
private Integer agentId;
private Integer chatId;
private String queryText;
private Long queryId = 7L;
private Integer parseId = 2;
private Long queryId;
private Integer parseId;
private SemanticParseInfo parseInfo;
private boolean saveAnswer = true;
private boolean saveAnswer;
}

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.api.pojo.request;

import lombok.Data;
import java.util.List;

@Data
public class PageQueryInfoReq {
@@ -11,27 +12,9 @@ public class PageQueryInfoReq {

private String userName;

public int getPageSize() {
return pageSize;
}
private List<Long> ids;

public void setPageSize(int pageSize) {
this.pageSize = pageSize;
}

public int getCurrent() {
return current;
}

public void setCurrent(int current) {
this.current = current;
}

public String getUserName() {
return userName;
}

public void setUserName(String userName) {
this.userName = userName;
public Integer getLimitStart() {
return this.pageSize * (this.current - 1);
}
}
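The new `getLimitStart()` turns the 1-based page number into a zero-based row offset for a `limit offset, size` style query. A worked example (values are illustrative; the setters are generated by Lombok's `@Data` now that the hand-written accessors are gone):

```java
// Worked example of the offset arithmetic above.
PageQueryInfoReq req = new PageQueryInfoReq();
req.setPageSize(10);
req.setCurrent(3);
int offset = req.getLimitStart();   // 10 * (3 - 1) = 20, i.e. rows 21-30 of the result set
```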
@@ -1,25 +1,21 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class QueryDataReq {
|
||||
String queryMode;
|
||||
SchemaElement model;
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
Set<QueryFilter> dimensionFilters = new HashSet<>();
|
||||
Set<QueryFilter> metricFilters = new HashSet<>();
|
||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||
private Set<Order> orders = new HashSet<>();
|
||||
private User user;
|
||||
private Set<SchemaElement> metrics = new HashSet<>();
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
private Set<QueryFilter> dimensionFilters = new HashSet<>();
|
||||
private Set<QueryFilter> metricFilters = new HashSet<>();
|
||||
private DateConf dateInfo;
|
||||
private Long limit;
|
||||
private Boolean nativeQuery = false;
|
||||
private Long queryId;
|
||||
private Integer parseId;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import java.util.Objects;
|
||||
import com.google.common.base.Objects;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@@ -19,6 +19,8 @@ public class QueryFilter {
|
||||
|
||||
private Long elementID;
|
||||
|
||||
private String function;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
@@ -27,14 +29,15 @@ public class QueryFilter {
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
QueryFilter filter = (QueryFilter) o;
|
||||
return Objects.equals(bizName, filter.bizName) && Objects.equals(name, filter.name)
|
||||
&& operator == filter.operator && Objects.equals(value, filter.value) && Objects.equals(
|
||||
elementID, filter.elementID);
|
||||
QueryFilter that = (QueryFilter) o;
|
||||
return Objects.equal(bizName, that.bizName) && Objects.equal(name,
|
||||
that.name) && operator == that.operator && Objects.equal(value, that.value)
|
||||
&& Objects.equal(elementID, that.elementID) && Objects.equal(
|
||||
function, that.function);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(bizName, name, operator, value, elementID);
|
||||
return Objects.hashCode(bizName, name, operator, value, elementID, function);
|
||||
}
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
||||
public class QueryReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId = 0L;
|
||||
private Long modelId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class RecommendReq {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private Long metricId;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
@Builder
|
||||
public class SolvedQueryReq {
|
||||
|
||||
private Long queryId;
|
||||
|
||||
private Integer parseId;
|
||||
|
||||
private String queryText;
|
||||
|
||||
private String modelId;
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
}
|
||||
@@ -23,6 +23,8 @@ public class ChatConfigResp {
|
||||
|
||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||
|
||||
private String llmExamples;
|
||||
|
||||
/**
|
||||
* available status
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@ToString
|
||||
@Data
|
||||
public class DictLatestTaskResp {
|
||||
|
||||
private Long dimId;
|
||||
|
||||
private Long id;
|
||||
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
||||
private String command;
|
||||
|
||||
private TaskStatusEnum status;
|
||||
|
||||
private String createdBy;
|
||||
|
||||
private Date createdAt;
|
||||
|
||||
private Long elapsedMs;
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelInfo extends DataInfo implements Serializable {
|
||||
|
||||
private List<String> words;
|
||||
private String primaryEntityBizName;
|
||||
private String primaryKey;
|
||||
}
|
||||
|
||||
@@ -1,30 +1,24 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Builder;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ParseResp {
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId;
|
||||
private ParseState state;
|
||||
private List<SemanticParseInfo> selectedParses;
|
||||
private List<SemanticParseInfo> candidateParses;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
|
||||
|
||||
public enum ParseState {
|
||||
COMPLETED,
|
||||
PENDING,
|
||||
FAILED
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ParseTimeCostDO {
|
||||
|
||||
private long parseStartTime;
|
||||
private long parseTime;
|
||||
private long sqlTime;
|
||||
|
||||
public ParseTimeCostDO() {
|
||||
this.parseStartTime = System.currentTimeMillis();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class QueryRecallResp {
|
||||
private List<SolvedQueryRecallResp> solvedQueryRecallRespList;
|
||||
private Long queryTimeCost;
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@@ -13,4 +15,5 @@ public class QueryResp {
|
||||
private String feedback;
|
||||
private String queryText;
|
||||
private QueryResult queryResult;
|
||||
private List<SemanticParseInfo> parseInfos;
|
||||
}
|
||||
@@ -21,4 +21,5 @@ public class QueryResult {
|
||||
private SemanticParseInfo chatContext;
|
||||
private Object response;
|
||||
private List<Map<String, Object>> queryResults;
|
||||
private Long queryTimeCost;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class SolvedQueryRecallResp {
|
||||
|
||||
private Long queryId;
|
||||
|
||||
private Integer parseId;
|
||||
|
||||
private String queryText;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SqlInfo {
|
||||
|
||||
private String s2SQL;
|
||||
private String correctS2SQL;
|
||||
private String querySQL;
|
||||
}
|
||||
@@ -59,16 +59,7 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mybatis</groupId>
|
||||
<artifactId>mybatis</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mybatis.spring.boot</groupId>
|
||||
<artifactId>mybatis-spring-boot-starter</artifactId>
|
||||
<version>${mybatis-spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>druid</artifactId>
|
||||
@@ -78,24 +69,6 @@
|
||||
<groupId>mysql</groupId>
|
||||
<artifactId>mysql-connector-java</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mybatis</groupId>
|
||||
<artifactId>mybatis-spring</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mybatis.spring.boot</groupId>
|
||||
<artifactId>mybatis-spring-boot-starter-test</artifactId>
|
||||
<version>${mybatis.test.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.pagehelper</groupId>
|
||||
<artifactId>pagehelper-spring-boot-starter</artifactId>
|
||||
<version>${pagehelper.spring.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.h2database</groupId>
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
public enum AgentToolType {
|
||||
RULE,
|
||||
DSL,
|
||||
LLM_S2SQL,
|
||||
PLUGIN,
|
||||
INTERPRET
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class CommonAgentTool extends AgentTool {
|
||||
|
||||
protected List<Long> modelIds;
|
||||
|
||||
}
|
||||
@@ -5,9 +5,7 @@ import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class DslTool extends AgentTool {
|
||||
|
||||
private List<Long> modelIds;
|
||||
public class LLMParserTool extends CommonAgentTool {
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
@@ -7,12 +7,13 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RuleQueryTool extends AgentTool {
|
||||
public class RuleQueryTool extends CommonAgentTool {
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
private List<String> queryModes;
|
||||
|
||||
private List<String> queryTypes;
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
||||
}
|
||||
|
||||
@@ -7,13 +7,19 @@ import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
public class LLMConfig {
|
||||
public class LLMParserConfig {
|
||||
|
||||
|
||||
@Value("${llm.url:}")
|
||||
@Value("${llm.parser.url:}")
|
||||
private String url;
|
||||
|
||||
@Value("${query2sql.path:/query2sql}")
|
||||
private String queryToSqlPath;
|
||||
|
||||
@Value("${dimension.topn:5}")
|
||||
private Integer dimensionTopN;
|
||||
|
||||
@Value("${metric.topn:5}")
|
||||
private Integer metricTopN;
|
||||
|
||||
}
|
||||
@@ -1,43 +1,161 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.PropertySource;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
@PropertySource("classpath:optimization.properties")
|
||||
//@ComponentScan(basePackages = "com.tencent.supersonic.chat")
|
||||
@Slf4j
|
||||
public class OptimizationConfig {
|
||||
|
||||
@Value("${one.detection.size}")
|
||||
@Value("${one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
@Value("${one.detection.max.size}")
|
||||
|
||||
@Value("${one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
|
||||
@Value("${metric.dimension.min.threshold}")
|
||||
@Value("${metric.dimension.min.threshold:0.3}")
|
||||
private Double metricDimensionMinThresholdConfig;
|
||||
|
||||
@Value("${metric.dimension.threshold}")
|
||||
@Value("${metric.dimension.threshold:0.3}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.threshold}")
|
||||
@Value("${dimension.value.threshold:0.5}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
|
||||
@Value("${function.bonus.threshold}")
|
||||
private Double functionBonusThreshold;
|
||||
|
||||
@Value("${long.text.threshold}")
|
||||
@Value("${long.text.threshold:0.8}")
|
||||
private Double longTextThreshold;
|
||||
|
||||
@Value("${short.text.threshold}")
|
||||
@Value("${short.text.threshold:0.5}")
|
||||
private Double shortTextThreshold;
|
||||
|
||||
@Value("${query.text.length.threshold}")
|
||||
@Value("${query.text.length.threshold:10}")
|
||||
private Integer queryTextLengthThreshold;
|
||||
@Value("${embedding.mapper.word.min:4}")
|
||||
private int embeddingMapperWordMin;
|
||||
|
||||
@Value("${candidate.threshold}")
|
||||
private Double candidateThreshold;
|
||||
@Value("${embedding.mapper.word.max:5}")
|
||||
private int embeddingMapperWordMax;
|
||||
|
||||
@Value("${embedding.mapper.batch:50}")
|
||||
private int embeddingMapperBatch;
|
||||
|
||||
@Value("${embedding.mapper.number:5}")
|
||||
private int embeddingMapperNumber;
|
||||
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
    private int embeddingMapperRoundNumber;

    @Value("${embedding.mapper.distance.threshold:0.58}")
    private Double embeddingMapperDistanceThreshold;

    @Value("${s2SQL.linking.value.switch:true}")
    private boolean useLinkingValueSwitch;

    @Value("${s2SQL.use.switch:true}")
    private boolean useS2SqlSwitch;

    @Value("${text2sql.example.num:10}")
    private int text2sqlExampleNum;

    @Value("${text2sql.fewShots.num:10}")
    private int text2sqlFewShotsNum;

    @Value("${text2sql.self.consistency.num:5}")
    private int text2sqlSelfConsistencyNum;

    @Value("${text2sql.collection.name:text2dsl_agent_collection}")
    private String text2sqlCollectionName;

    @Autowired
    private SysParameterService sysParameterService;

    public Integer getOneDetectionSize() {
        return convertValue("one.detection.size", Integer.class, oneDetectionSize);
    }

    public Integer getOneDetectionMaxSize() {
        return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
    }

    public Double getMetricDimensionMinThresholdConfig() {
        return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
    }

    public Double getMetricDimensionThresholdConfig() {
        return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
    }

    public Double getDimensionValueThresholdConfig() {
        return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
    }

    public Double getLongTextThreshold() {
        return convertValue("long.text.threshold", Double.class, longTextThreshold);
    }

    public Double getShortTextThreshold() {
        return convertValue("short.text.threshold", Double.class, shortTextThreshold);
    }

    public Integer getQueryTextLengthThreshold() {
        return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
    }

    public boolean isUseS2SqlSwitch() {
        return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
    }

    public Integer getEmbeddingMapperWordMin() {
        return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
    }

    public Integer getEmbeddingMapperWordMax() {
        return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
    }

    public Integer getEmbeddingMapperBatch() {
        return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
    }

    public Integer getEmbeddingMapperNumber() {
        return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
    }

    public Integer getEmbeddingMapperRoundNumber() {
        return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
    }

    public Double getEmbeddingMapperDistanceThreshold() {
        return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
    }

    public boolean isUseLinkingValueSwitch() {
        return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
    }

    public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
        try {
            String value = sysParameterService.getSysParameter().getParameterByName(paramName);
            if (StringUtils.isBlank(value)) {
                return defaultValue;
            }
            if (targetType == Double.class) {
                return targetType.cast(Double.parseDouble(value));
            } else if (targetType == Integer.class) {
                return targetType.cast(Integer.parseInt(value));
            } else if (targetType == Boolean.class) {
                return targetType.cast(Boolean.parseBoolean(value));
            }
        } catch (Exception e) {
            log.error("convertValue", e);
        }
        return defaultValue;
    }

}
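The getters above all route through convertValue(...), so an admin-set system parameter overrides the @Value default only when it is non-blank and parses into the target type. A minimal standalone sketch of that fallback pattern (plain Java; the in-memory map is a hypothetical stand-in for SysParameterService, not the project's API):

import java.util.HashMap;
import java.util.Map;

public class TypedParamFallbackSketch {

    // Hypothetical in-memory stand-in for a system-parameter store.
    private static final Map<String, String> PARAMS = new HashMap<>();

    // Mirrors the convertValue contract: blank or unparsable values fall back to the default.
    static <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
        try {
            String value = PARAMS.get(paramName);
            if (value == null || value.trim().isEmpty()) {
                return defaultValue;
            }
            if (targetType == Double.class) {
                return targetType.cast(Double.parseDouble(value));
            } else if (targetType == Integer.class) {
                return targetType.cast(Integer.parseInt(value));
            } else if (targetType == Boolean.class) {
                return targetType.cast(Boolean.parseBoolean(value));
            }
        } catch (Exception e) {
            // Parsing failure also falls back to the default.
        }
        return defaultValue;
    }

    public static void main(String[] args) {
        // No override present: the annotated default (0.58) wins.
        System.out.println(convertValue("embedding.mapper.distance.threshold", Double.class, 0.58));
        // Override present and parsable: the stored value wins.
        PARAMS.put("embedding.mapper.distance.threshold", "0.75");
        System.out.println(convertValue("embedding.mapper.distance.threshold", Double.class, 0.58));
        // Unparsable override: back to the default.
        PARAMS.put("embedding.mapper.batch", "not-a-number");
        System.out.println(convertValue("embedding.mapper.batch", Integer.class, 50));
    }
}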
@@ -2,20 +2,51 @@ package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;

/**
 * basic semantic correction functionality, offering common methods and an
 * abstract method called doCorrect
 */
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
    public static final String DATE_FIELD = "数据日期";
    protected Map<String, String> getFieldToBizName(Long modelId) {

    public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
        try {
            if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
                return;
            }
            doCorrect(queryReq, semanticParseInfo);
            log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
        } catch (Exception e) {
            log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
        }
    }

    public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);

    protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {

        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

@@ -23,11 +54,78 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
        dbAllFields.addAll(semanticSchema.getMetrics());
        dbAllFields.addAll(semanticSchema.getDimensions());

        // support fieldName and field alias
        Map<String, String> result = dbAllFields.stream()
                .filter(entry -> entry.getModel().equals(modelId))
                .collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
        result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
                .filter(entry -> modelIds.contains(entry.getModel()))
                .flatMap(schemaElement -> {
                    Set<String> elements = new HashSet<>();
                    elements.add(schemaElement.getName());
                    if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
                        elements.addAll(schemaElement.getAlias());
                    }
                    return elements.stream();
                })
                .collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
        result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
        result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
        result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());

        result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
        result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
        result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());

        return result;
    }

    protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
        Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
        Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
        needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));

        // If there is no aggregate function in the S2SQL statement and
        // there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
        if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
            List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
            List<String> timeChNameList = TimeDimensionEnum.getChNameList();
            Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
                    .collect(Collectors.toSet());
            needAddFields.addAll(timeFields);
        }

        if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
            return;
        }

        needAddFields.removeAll(selectFields);
        String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
    }

    protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
        //add aggregate to all metric
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();

        List<SchemaElement> metrics = getMetricElements(modelIds);

        Map<String, String> metricToAggregate = metrics.stream()
                .map(schemaElement -> {
                    if (Objects.isNull(schemaElement.getDefaultAgg())) {
                        schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
                    }
                    return schemaElement;
                }).collect(Collectors.toMap(a -> a.getName(), a -> a.getDefaultAgg(), (k1, k2) -> k1));

        if (CollectionUtils.isEmpty(metricToAggregate)) {
            return;
        }
        String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
    }

    protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        return semanticSchema.getMetrics(modelIds);
    }

}
@@ -1,27 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

@Slf4j
public class DateFieldCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {

        String sql = semanticCorrectInfo.getSql();
        List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
        if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
            String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
            sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
        }
        semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
        semanticCorrectInfo.setSql(sql);
    }

}
@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class FieldCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        String preSql = semanticCorrectInfo.getSql();
        semanticCorrectInfo.setPreSql(preSql);
        String sql = SqlParserUpdateHelper.replaceFields(preSql,
                getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
        semanticCorrectInfo.setSql(sql);
    }
}
@@ -1,49 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

@Slf4j
public class FieldNameCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {

        Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
        if (Objects.isNull(context)) {
            return;
        }

        DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
        if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
            return;
        }
        LLMReq llmReq = dslParseResult.getLlmReq();
        List<ElementValue> linking = llmReq.getLinking();
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }

        Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
                Collectors.groupingBy(ElementValue::getFieldValue,
                        Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));

        String preSql = semanticCorrectInfo.getSql();
        semanticCorrectInfo.setPreSql(preSql);
        String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
        semanticCorrectInfo.setSql(sql);
    }

}
@@ -1,81 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;

@Slf4j
public class FieldValueCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
        List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
                .filter(schemaElement -> modelId.equals(schemaElement.getModel()))
                .collect(Collectors.toList());

        if (CollectionUtils.isEmpty(dimensions)) {
            return;
        }

        Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
        String preSql = semanticCorrectInfo.getSql();
        semanticCorrectInfo.setPreSql(preSql);
        String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
        semanticCorrectInfo.setSql(sql);
        return;
    }

    private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
        if (CollectionUtils.isEmpty(dimensions)) {
            return new HashMap<>();
        }

        Map<String, Map<String, String>> result = new HashMap<>();

        for (SchemaElement dimension : dimensions) {
            if (Objects.isNull(dimension)
                    || Strings.isEmpty(dimension.getBizName())
                    || CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
                continue;
            }
            String bizName = dimension.getBizName();

            Map<String, String> aliasAndBizNameToTechName = new HashMap<>();

            for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
                if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
                    continue;
                }
                if (Strings.isNotEmpty(valueMap.getBizName())) {
                    aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
                }
                if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
                    valueMap.getAlias().stream().forEach(alias -> {
                        if (Strings.isNotEmpty(alias)) {
                            aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
                        }
                    });
                }
            }
            if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
                result.put(bizName, aliasAndBizNameToTechName);
            }
        }
        return result;
    }
}
@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class FromCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
        String modelName = semanticParseInfo.getModel().getName();
        SqlParserReplaceHelper.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
    }

}
@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class FunctionAliasCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
        semanticCorrectInfo.setSql(replaceAlias);
    }

}
@@ -1,17 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class FunctionCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        String preSql = semanticCorrectInfo.getSql();
        semanticCorrectInfo.setPreSql(preSql);
        String sql = SqlParserUpdateHelper.replaceFunction(preSql);
        semanticCorrectInfo.setSql(sql);
    }
}
@@ -0,0 +1,91 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Perform SQL corrections on the "group by" section in S2SQL.
 */
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {

        addGroupByFields(semanticParseInfo);

    }

    private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
        Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();

        //add dimension group by
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String correctS2SQL = sqlInfo.getCorrectS2SQL();
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        //add alias field name
        Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
                .flatMap(
                        schemaElement -> {
                            Set<String> elements = new HashSet<>();
                            elements.add(schemaElement.getName());
                            if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
                                elements.addAll(schemaElement.getAlias());
                            }
                            return elements.stream();
                        }
                ).collect(Collectors.toSet());
        dimensions.add(TimeDimensionEnum.DAY.getChName());

        List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);

        if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
            return;
        }
        // if only date in select not add group by.
        if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
            return;
        }
        if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
            log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
            return;
        }

        List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
        Set<String> groupByFields = selectFields.stream()
                .filter(field -> dimensions.contains(field))
                .filter(field -> {
                    if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) {
                        return false;
                    }
                    return true;
                })
                .collect(Collectors.toSet());
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));

        addAggregate(semanticParseInfo);
    }

    private void addAggregate(SemanticParseInfo semanticParseInfo) {
        List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
                semanticParseInfo.getSqlInfo().getCorrectS2SQL());
        if (CollectionUtils.isEmpty(sqlGroupByFields)) {
            return;
        }
        addAggregateToMetric(semanticParseInfo);
    }
}
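As a rough illustration of the candidate selection in GroupByCorrector, the sketch below picks group-by fields as "selected fields that are known dimensions and are not already aggregated". It works on plain strings rather than the project's SqlParser helpers, so the SQL handling is deliberately simplified and the field names are made up for the example:

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class GroupBySketch {

    public static void main(String[] args) {
        // Assumed inputs: what SqlParserSelectHelper would extract from the S2SQL.
        List<String> selectFields = Arrays.asList("department", "song_name", "play_count");
        Set<String> dimensions = new LinkedHashSet<>(Arrays.asList("department", "song_name"));
        Set<String> aggregateFields = new LinkedHashSet<>(Arrays.asList("play_count"));

        // Keep selected fields that are dimensions and are not already wrapped in an aggregate.
        Set<String> groupByFields = selectFields.stream()
                .filter(dimensions::contains)
                .filter(field -> !aggregateFields.contains(field))
                .collect(Collectors.toCollection(LinkedHashSet::new));

        // Simplified rendering of what SqlParserAddHelper.addGroupBy would produce.
        String sql = "SELECT department, song_name, SUM(play_count) FROM music";
        if (!groupByFields.isEmpty()) {
            sql = sql + " GROUP BY " + String.join(", ", groupByFields);
        }
        System.out.println(sql);
    }
}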
@@ -0,0 +1,66 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;

import java.util.Set;
import java.util.List;
import java.util.stream.Collectors;

import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;

/**
 * Perform SQL corrections on the "Having" section in S2SQL.
 */
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {

        //add aggregate to all metric
        addHaving(semanticParseInfo);

        //add having expression filed to select
        addHavingToSelect(semanticParseInfo);

    }

    private void addHaving(SemanticParseInfo semanticParseInfo) {
        Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();

        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

        Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
                .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());

        if (CollectionUtils.isEmpty(metrics)) {
            return;
        }
        String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
    }

    private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
            return;
        }
        List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
        if (!CollectionUtils.isEmpty(havingExpressionList)) {
            String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
        }
        return;
    }

}
@@ -1,48 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

@Slf4j
public class QueryFilterAppend extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
        String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
        String preSql = semanticCorrectInfo.getSql();

        if (StringUtils.isNotEmpty(queryFilter)) {
            log.info("add queryFilter to preSql :{}", queryFilter);
            Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
            String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
            semanticCorrectInfo.setPreSql(preSql);
            semanticCorrectInfo.setSql(sql);
        }
    }

    private String getQueryFilter(QueryFilters queryFilters) {
        if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
            return null;
        }
        return queryFilters.getFilters().stream()
                .map(filter -> {
                    String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
                    String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
                    String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
                    return bizNameWrap + operatorWrap + valueWrap;
                })
                .collect(Collectors.joining(Constants.AND_UPPER));
    }

}
@@ -0,0 +1,108 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

/**
 * Perform schema corrections on the Schema information in S2QL.
 */
@Slf4j
public class SchemaCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {

        correctAggFunction(semanticParseInfo);

        replaceAlias(semanticParseInfo);

        updateFieldNameByLinkingValue(semanticParseInfo);

        updateFieldValueByLinkingValue(semanticParseInfo);

        correctFieldName(semanticParseInfo);
    }

    private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
        Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
        sqlInfo.setCorrectS2SQL(sql);
    }

    private void replaceAlias(SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
        sqlInfo.setCorrectS2SQL(replaceAlias);
    }

    private void correctFieldName(SemanticParseInfo semanticParseInfo) {
        Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
        sqlInfo.setCorrectS2SQL(sql);
    }

    private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
        List<ElementValue> linking = getLinkingValues(semanticParseInfo);
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }

        Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
                Collectors.groupingBy(ElementValue::getFieldValue,
                        Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));

        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();

        String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
        sqlInfo.setCorrectS2SQL(sql);
    }

    private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
        Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
        if (Objects.isNull(context)) {
            return null;
        }

        ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
        if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
            return null;
        }
        return parseResult.getLinkingValues();
    }

    private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
        List<ElementValue> linking = getLinkingValues(semanticParseInfo);
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }

        Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
                Collectors.groupingBy(ElementValue::getFieldName,
                        Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
                                oldValue -> oldValue,
                                newValue -> newValue,
                                (existingValue, newValue) -> newValue)
                        )));

        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
        sqlInfo.setCorrectS2SQL(sql);
    }
}
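SchemaCorrector leans on the same grouping idiom in two places: linked values are bucketed either by value (value -> candidate field names) or by field name (field -> value map). A self-contained sketch of the first grouping, with a hypothetical ElementValue stand-in and made-up sample values:

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class LinkingValueGroupingSketch {

    // Hypothetical stand-in for LLMReq.ElementValue: a field name paired with a linked value.
    static final class ElementValue {
        final String fieldName;
        final String fieldValue;

        ElementValue(String fieldName, String fieldValue) {
            this.fieldName = fieldName;
            this.fieldValue = fieldValue;
        }
    }

    public static void main(String[] args) {
        List<ElementValue> linking = Arrays.asList(
                new ElementValue("singer_name", "Jay Chou"),
                new ElementValue("song_name", "Jay Chou"),
                new ElementValue("genre", "pop"));

        // Same shape as updateFieldNameByLinkingValue: value -> candidate field names.
        Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
                Collectors.groupingBy(e -> e.fieldValue,
                        Collectors.mapping(e -> e.fieldName, Collectors.toSet())));

        // e.g. {Jay Chou=[singer_name, song_name], pop=[genre]}
        System.out.println(fieldValueToFieldNames);
    }
}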
@@ -0,0 +1,29 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

/**
 * Perform SQL corrections on the "Select" section in S2SQL.
 */
@Slf4j
public class SelectCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
        List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
        // If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
        if (!CollectionUtils.isEmpty(aggregateFields)
                && !CollectionUtils.isEmpty(selectFields)
                && aggregateFields.size() == selectFields.size()) {
            return;
        }
        addFieldsToSelect(semanticParseInfo, correctS2SQL);
    }
}
@@ -1,38 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

@Slf4j
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        String preSql = semanticCorrectInfo.getSql();
        if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
            return;
        }
        Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));
        Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql));

        if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
            return;
        }

        whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql));
        whereFields.removeAll(selectFields);
        whereFields.remove(TimeDimensionEnum.DAY.getName());
        whereFields.remove(TimeDimensionEnum.WEEK.getName());
        whereFields.remove(TimeDimensionEnum.MONTH.getName());
        String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields));
        semanticCorrectInfo.setPreSql(preSql);
        semanticCorrectInfo.setSql(replaceFields);
    }
}
@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class TableNameCorrector extends BaseSemanticCorrector {

    public static final String TABLE_PREFIX = "t_";

    @Override
    public void correct(SemanticCorrectInfo semanticCorrectInfo) {
        Long modelId = semanticCorrectInfo.getParseInfo().getModelId();
        String preSql = semanticCorrectInfo.getSql();
        semanticCorrectInfo.setPreSql(preSql);
        String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
        semanticCorrectInfo.setSql(sql);
    }

}
@@ -0,0 +1,156 @@
package com.tencent.supersonic.chat.corrector;

import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.llm.s2sql.S2SQLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Perform SQL corrections on the "Where" section in S2SQL.
 */
@Slf4j
public class WhereCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {

        addDateIfNotExist(semanticParseInfo);

        parserDateDiffFunction(semanticParseInfo);

        addQueryFilter(queryReq, semanticParseInfo);

        updateFieldValueByTechName(semanticParseInfo);
    }

    private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
        String queryFilter = getQueryFilter(queryReq.getQueryFilters());

        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();

        if (StringUtils.isNotEmpty(queryFilter)) {
            log.info("add queryFilter to correctS2SQL :{}", queryFilter);
            Expression expression = null;
            try {
                expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
            } catch (JSQLParserException e) {
                log.error("parseCondExpression", e);
            }
            correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
        }
    }

    private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }

    private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
        if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
            String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
            if (StringUtils.isNotBlank(currentDate)) {
                correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
                correctS2SQL = SqlParserAddHelper.addWhere(
                        correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
            }
        }
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }

    private String getQueryFilter(QueryFilters queryFilters) {
        if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
            return null;
        }
        return queryFilters.getFilters().stream()
                .map(filter -> {
                    String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
                    String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
                    String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
                    return bizNameWrap + operatorWrap + valueWrap;
                })
                .collect(Collectors.joining(Constants.AND_UPPER));
    }

    private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
        List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);

        if (CollectionUtils.isEmpty(dimensions)) {
            return;
        }

        Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
        String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
                aliasAndBizNameToTechName);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }

    private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
        if (CollectionUtils.isEmpty(dimensions)) {
            return new HashMap<>();
        }

        Map<String, Map<String, String>> result = new HashMap<>();

        for (SchemaElement dimension : dimensions) {
            if (Objects.isNull(dimension)
                    || Strings.isEmpty(dimension.getName())
                    || CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
                continue;
            }
            String name = dimension.getName();

            Map<String, String> aliasAndBizNameToTechName = new HashMap<>();

            for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
                if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
                    continue;
                }
                if (Strings.isNotEmpty(valueMap.getBizName())) {
                    aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
                }
                if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
                    valueMap.getAlias().stream().forEach(alias -> {
                        if (Strings.isNotEmpty(alias)) {
                            aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
                        }
                    });
                }
            }
            if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
                result.put(name, aliasAndBizNameToTechName);
            }
        }
        return result;
    }
}
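The filter string produced by getQueryFilter is just "name operator value" fragments joined with AND, which is then handed to CCJSqlParserUtil and appended to the WHERE clause. A minimal plain-Java sketch of that construction, with a hypothetical Filter stand-in and invented field names:

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class QueryFilterJoinSketch {

    // Hypothetical stand-in for a QueryFilters entry: name, operator and value.
    static final class Filter {
        final String name;
        final String operator;
        final Object value;

        Filter(String name, String operator, Object value) {
            this.name = name;
            this.operator = operator;
            this.value = value;
        }
    }

    public static void main(String[] args) {
        List<Filter> filters = Arrays.asList(
                new Filter("publish_year", ">=", 2020),
                new Filter("genre", "=", "'pop'"));

        // Same shape as getQueryFilter: each filter rendered as "name op value", joined with AND.
        String condition = filters.stream()
                .map(f -> f.name + " " + f.operator + " " + f.value)
                .collect(Collectors.joining(" AND "));

        // publish_year >= 2020 AND genre = 'pop'
        System.out.println(condition);
    }
}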
@@ -0,0 +1,104 @@
package com.tencent.supersonic.chat.mapper;

import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;

/**
 * base Mapper
 */
@Slf4j
public abstract class BaseMapper implements SchemaMapper {

    @Override
    public void map(QueryContext queryContext) {

        String simpleName = this.getClass().getSimpleName();
        long startTime = System.currentTimeMillis();
        log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());

        try {
            doMap(queryContext);
        } catch (Exception e) {
            log.error("work error", e);
        }

        long cost = System.currentTimeMillis() - startTime;
        log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
    }

    public abstract void doMap(QueryContext queryContext);

    public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
        Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
        List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
        if (schemaElementMatches == null) {
            schemaElementMatches = modelElementMatches.get(modelId);
        }
        //remove duplication
        AtomicBoolean needAddNew = new AtomicBoolean(true);
        schemaElementMatches.removeIf(
                existElementMatch -> {
                    SchemaElement existElement = existElementMatch.getElement();
                    SchemaElement newElement = newElementMatch.getElement();
                    if (existElement.equals(newElement)) {
                        if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
                            return true;
                        } else {
                            needAddNew.set(false);
                        }
                    }
                    return false;
                }
        );
        if (needAddNew.get()) {
            schemaElementMatches.add(newElementMatch);
        }
    }

    public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
        SchemaElement element = new SchemaElement();
        SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
        ModelSchema modelSchema = schemaService.getModelSchema(modelId);
        if (Objects.isNull(modelSchema)) {
            return null;
        }
        SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
        if (Objects.isNull(elementDb)) {
            log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
            return null;
        }
        BeanUtils.copyProperties(elementDb, element);
        element.setAlias(getAlias(elementDb));
        return element;
    }

    public List<String> getAlias(SchemaElement element) {
        if (!SchemaElementType.VALUE.equals(element.getType())) {
            return element.getAlias();
        }
        if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
                element.getName())) {
            return element.getAlias().stream()
                    .filter(aliasItem -> aliasItem.contains(element.getName()))
                    .collect(Collectors.toList());
        }
        return element.getAlias();
    }
}
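The deduplication in addToSchemaMap keeps at most one match per schema element and prefers the higher similarity. A self-contained sketch of that removeIf/AtomicBoolean pattern, using a hypothetical Match class instead of SchemaElementMatch:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

public class KeepBestMatchSketch {

    // Hypothetical stand-in for SchemaElementMatch: an element id plus a similarity score.
    static final class Match {
        final long elementId;
        final double similarity;

        Match(long elementId, double similarity) {
            this.elementId = elementId;
            this.similarity = similarity;
        }
    }

    // Mirrors addToSchemaMap: keep at most one match per element, preferring the higher similarity.
    static void addMatch(List<Match> matches, Match newMatch) {
        AtomicBoolean needAddNew = new AtomicBoolean(true);
        matches.removeIf(exist -> {
            if (exist.elementId == newMatch.elementId) {
                if (newMatch.similarity > exist.similarity) {
                    return true;          // drop the weaker existing match
                }
                needAddNew.set(false);    // existing match is at least as good
            }
            return false;
        });
        if (needAddNew.get()) {
            matches.add(newMatch);
        }
    }

    public static void main(String[] args) {
        List<Match> matches = new ArrayList<>();
        addMatch(matches, new Match(1L, 0.72));
        addMatch(matches, new Match(1L, 0.90)); // replaces the 0.72 match
        addMatch(matches, new Match(1L, 0.60)); // ignored, weaker than 0.90
        matches.forEach(m -> System.out.println(m.elementId + " -> " + m.similarity));
    }
}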
@@ -0,0 +1,158 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Base Match Strategy
 */
@Service
@Slf4j
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {

    @Autowired
    private MapperHelper mapperHelper;

    @Override
    public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
        String text = queryContext.getRequest().getQueryText();
        if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
            return null;
        }

        log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);

        List<T> detects = detect(queryContext, terms, detectModelIds);
        Map<MatchText, List<T>> result = new HashMap<>();

        result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
        return result;
    }

    public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
        Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
        String text = queryContext.getRequest().getQueryText();
        Set<T> results = new HashSet<>();

        Set<String> detectSegments = new HashSet<>();

        for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {

            for (Integer index = startIndex; index <= text.length(); ) {
                int offset = mapperHelper.getStepOffset(terms, startIndex);
                index = mapperHelper.getStepIndex(regOffsetToLength, index);
                if (index <= text.length()) {
                    String detectSegment = text.substring(startIndex, index);
                    detectSegments.add(detectSegment);
                    detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
                }
            }
            startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
        }
        detectByBatch(queryContext, results, detectModelIds, detectSegments);
        return new ArrayList<>(results);
    }

    protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
            Set<String> detectSegments) {
        return;
    }

    public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
        return terms.stream().sorted(Comparator.comparing(Term::length))
                .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
    }

    public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
        if (CollectionUtils.isEmpty(oneRoundResults)) {
            return;
        }
        for (T oneRoundResult : oneRoundResults) {
            if (existResults.contains(oneRoundResult)) {
                boolean isDeleted = existResults.removeIf(
                        existResult -> {
                            boolean delete = needDelete(oneRoundResult, existResult);
                            if (delete) {
                                log.info("deleted existResult:{}", existResult);
                            }
                            return delete;
                        }
                );
                if (isDeleted) {
                    log.info("deleted, add oneRoundResult:{}", oneRoundResult);
                    existResults.add(oneRoundResult);
                }
            } else {
                existResults.add(oneRoundResult);
            }
        }
    }

    public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
        Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
        terms = filterByModelIds(terms, detectModelIds);
        Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
        List<T> matches = new ArrayList<>();
        if (Objects.isNull(matchResult)) {
            return matches;
        }
        Optional<List<T>> first = matchResult.entrySet().stream()
                .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
                .map(entry -> entry.getValue()).findFirst();

        if (first.isPresent()) {
            matches = first.get();
        }
        return matches;
    }

    public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
        logTerms(terms);
        if (CollectionUtils.isNotEmpty(detectModelIds)) {
            terms = terms.stream().filter(term -> {
                Long modelId = NatureHelper.getModelId(term.getNature().toString());
                if (Objects.nonNull(modelId)) {
                    return detectModelIds.contains(modelId);
                }
                return false;
            }).collect(Collectors.toList());
            log.info("terms filter by modelIds:{}", detectModelIds);
            logTerms(terms);
        }
        return terms;
    }

    public void logTerms(List<Term> terms) {
        if (CollectionUtils.isEmpty(terms)) {
            return;
        }
        for (Term term : terms) {
            log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
        }
    }

    public abstract boolean needDelete(T oneRoundResult, T existResult);

    public abstract String getMapKey(T a);

    public abstract void detectByStep(QueryContext queryContext, Set<T> results,
            Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);

}
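The detect loop above enumerates candidate substrings of the query text between term boundaries and collects them as detectSegments for the batch step. A much-simplified standalone sketch of that idea, stepping one character at a time instead of using MapperHelper's term-aware step indexes, and with an assumed length cap standing in for the embedding.mapper.word.max setting:

import java.util.LinkedHashSet;
import java.util.Set;

public class SegmentEnumerationSketch {

    public static void main(String[] args) {
        String text = "play count of pop songs";
        // Simplification: step one character at a time; the real strategy steps by term lengths
        // derived from the HanLP term offsets via MapperHelper.
        Set<String> detectSegments = new LinkedHashSet<>();
        int maxLen = 10; // assumed upper bound, similar in spirit to embedding.mapper.word.max

        for (int start = 0; start < text.length(); start++) {
            for (int end = start + 1; end <= Math.min(text.length(), start + maxLen); end++) {
                detectSegments.add(text.substring(start, end).trim());
            }
        }
        System.out.println("candidate segments: " + detectSegments.size());
    }
}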
@@ -0,0 +1,62 @@
package com.tencent.supersonic.chat.mapper;

import com.alibaba.fastjson.JSONObject;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

/***
 * A mapper that is capable of semantic understanding of text.
 */
@Slf4j
public class EmbeddingMapper extends BaseMapper {

    @Override
    public void doMap(QueryContext queryContext) {
        //1. query from embedding by queryText

        String queryText = queryContext.getRequest().getQueryText();
        List<Term> terms = HanlpHelper.getTerms(queryText);

        EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
        List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);

        HanlpHelper.transLetterOriginal(matchResults);

        //2. build SchemaElementMatch by info
        for (EmbeddingResult matchResult : matchResults) {
            Long elementId = Retrieval.getLongId(matchResult.getId());

            SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
                    SchemaElement.class);

            if (StringUtils.isBlank(matchResult.getMetadata().get("modelId"))) {
                continue;
            }
            long modelId = Long.parseLong(matchResult.getMetadata().get("modelId"));

            schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
            if (schemaElement == null) {
                continue;
            }
            SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                    .element(schemaElement)
                    .frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
                    .word(matchResult.getName())
                    .similarity(1 - matchResult.getDistance())
                    .detectWord(matchResult.getDetectWord())
                    .build();
            //3. add to mapInfo
            addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
        }
    }
}
@@ -0,0 +1,136 @@
package com.tencent.supersonic.chat.mapper;

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * match strategy implement
 */
@Service
@Slf4j
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {

    @Autowired
    private OptimizationConfig optimizationConfig;
    @Autowired
    private EmbeddingUtils embeddingUtils;

    @Override
    public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
        return getMapKey(oneRoundResult).equals(getMapKey(existResult))
                && existResult.getDistance() > oneRoundResult.getDistance();
    }

    @Override
    public String getMapKey(EmbeddingResult a) {
        return a.getName() + Constants.UNDERLINE + a.getId();
    }

    @Override
    protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
            Set<String> detectSegments) {

        List<String> queryTextsList = detectSegments.stream()
                .map(detectSegment -> detectSegment.trim())
                .filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
                        && detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
                        && detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
                .collect(Collectors.toList());

        List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
                optimizationConfig.getEmbeddingMapperBatch());

        for (List<String> queryTextsSub : queryTextsSubList) {
            detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
        }
    }

    private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
            List<String> queryTextsSub) {
        int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
        Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
        Map<String, String> filterCondition = null;
        // step1. build query params
        // if only one modelId, add to filterCondition
        if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
            filterCondition = new HashMap<>();
            filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
        }

        RetrieveQuery retrieveQuery = RetrieveQuery.builder()
                .queryTextsList(queryTextsSub)
                .filterCondition(filterCondition)
                .queryEmbeddings(null)
                .build();
        // step2. retrieveQuery by detectSegment
        List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
                MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);

        if (CollectionUtils.isEmpty(retrieveQueryResults)) {
            return;
        }
        // step3. build EmbeddingResults. filter by modelId
        List<EmbeddingResult> collect = retrieveQueryResults.stream()
                .map(retrieveQueryResult -> {
                    List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
                    if (CollectionUtils.isNotEmpty(retrievals)) {
                        retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
                        if (CollectionUtils.isNotEmpty(detectModelIds)) {
                            retrievals.removeIf(retrieval -> {
                                String modelIdStr = retrieval.getMetadata().get("modelId");
                                if (StringUtils.isBlank(modelIdStr)) {
                                    return true;
                                }
                                return detectModelIds.contains(Long.parseLong(modelIdStr));
                            });
                        }
                    }
                    return retrieveQueryResult;
                })
                .filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
                .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
                        .map(retrieval -> {
                            EmbeddingResult embeddingResult = new EmbeddingResult();
                            BeanUtils.copyProperties(retrieval, embeddingResult);
                            embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
                            embeddingResult.setName(retrieval.getQuery());
                            return embeddingResult;
                        }))
                .collect(Collectors.toList());

        // step4. select mapResul in one round
        int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
        List<EmbeddingResult> oneRoundResults = collect.stream()
                .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
                .limit(roundNumber)
                .collect(Collectors.toList());
        selectResultInOneRound(results, oneRoundResults);
    }

    @Override
    public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
            Integer startIndex, Integer index, int offset) {
        return;
    }
}
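For illustration (not part of this diff): a minimal standalone sketch of the batching pattern used by detectByBatch above, where candidate segments are split into fixed-size batches via Guava's Lists.partition before each retrieval call. The names and the fake retrieve method below are hypothetical stand-ins.

import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BatchDetectSketch {

    public static void main(String[] args) {
        List<String> segments = Arrays.asList("revenue", "daily active users", "region", "song name");
        int batchSize = 2; // stands in for optimizationConfig.getEmbeddingMapperBatch()

        List<String> hits = new ArrayList<>();
        // Lists.partition splits the candidate segments into fixed-size batches,
        // so each retrieval call carries at most batchSize query texts.
        for (List<String> batch : Lists.partition(segments, batchSize)) {
            hits.addAll(retrieve(batch));
        }
        System.out.println(hits);
    }

    // stand-in for the embedding retrieval call
    private static List<String> retrieve(List<String> batch) {
        return batch;
    }
}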
@@ -1,27 +1,29 @@
package com.tencent.supersonic.chat.mapper;

import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.stream.Collectors;

/**
 * A mapper capable of converting the VALUE of entity dimension values into ID types.
 */
@Slf4j
public class EntityMapper implements SchemaMapper {
public class EntityMapper extends BaseMapper {

    @Override
    public void map(QueryContext queryContext) {
    public void doMap(QueryContext queryContext) {
        SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
        for (Long modelId : schemaMapInfo.getMatchedModels()) {
            List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
@@ -32,8 +34,9 @@ public class EntityMapper implements SchemaMapper {
            if (entity == null || entity.getId() == null) {
                continue;
            }
            List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
                    SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
            List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
                    .filter(schemaElementMatch ->
                            SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
                    .collect(Collectors.toList());
            for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
                if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
@@ -51,7 +54,7 @@ public class EntityMapper implements SchemaMapper {
    }

    private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
            List<SchemaElementMatch> schemaElementMatchList) {
            List<SchemaElementMatch> schemaElementMatchList) {
        List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
                SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
                .collect(Collectors.toList());
@@ -1,179 +1,67 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

/***
 * A mapper capable of fuzzy parsing of metric names and dimension names.
 */
@Slf4j
public class FuzzyNameMapper implements SchemaMapper {
public class FuzzyNameMapper extends BaseMapper {

    @Override
    public void map(QueryContext queryContext) {

        log.debug("before db mapper,mapInfo:{}", queryContext.getMapInfo());
    public void doMap(QueryContext queryContext) {

        List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());

        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

        detectAndAddToSchema(queryContext, terms, semanticSchema.getDimensions(), SchemaElementType.DIMENSION);

        detectAndAddToSchema(queryContext, terms, semanticSchema.getMetrics(), SchemaElementType.METRIC);

        log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
    }

    private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
            SchemaElementType schemaElementType) {
        try {

            Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);

            addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);

        } catch (Exception e) {
            log.error("detectAndAddToSchema error", e);
        }
    }

    private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
            List<SchemaElement> models) {

        String queryText = queryContext.getRequest().getQueryText();
        FuzzyNameMatchStrategy fuzzyNameMatchStrategy = ContextUtils.getBean(FuzzyNameMatchStrategy.class);

        MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
        Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());

        Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
        List<FuzzyResult> matches = fuzzyNameMatchStrategy.getMatches(queryContext, terms);

        Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);

        Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
                .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));

        Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
        for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
            for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
                endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
                if (endIndex > queryText.length()) {
                    continue;
                }
                String detectSegment = queryText.substring(startIndex, endIndex);

                for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
                    String name = entry.getKey();
                    Set<SchemaElement> schemaElements = entry.getValue();
                    if (!name.contains(detectSegment)
                            || mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
                        continue;
                    }
                    if (!CollectionUtils.isEmpty(modelIds)) {
                        schemaElements = schemaElements.stream()
                                .filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
                                .collect(Collectors.toSet());
                    }
                    Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
                    if (Objects.nonNull(preSchemaElements)) {
                        preSchemaElements.addAll(schemaElements);
                    }
                }
        for (FuzzyResult match : matches) {
            SchemaElement schemaElement = match.getSchemaElement();
            Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
            if (regElementSet.contains(schemaElement.getId())) {
                continue;
            }
            startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
            SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                    .element(schemaElement)
                    .word(schemaElement.getName())
                    .detectWord(match.getDetectWord())
                    .frequency(10000L)
                    .similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
                    .build();
            log.info("add to schema, elementMatch {}", schemaElementMatch);
            addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
        }
        return modelResultSet;
    }

    private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {

        OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
        Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
        Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();

        Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
                .getModelElementMatches();
        boolean existElement = modelElementMatches.entrySet().stream()
                .anyMatch(entry -> entry.getValue().size() >= 1);

        if (!existElement) {
            double halfThreshold = metricDimensionThresholdConfig / 2;

            metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
                    : metricDimensionMinThresholdConfig;
            log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
                    modelElementMatches, metricDimensionThresholdConfig);
        }
        return metricDimensionThresholdConfig;
    }

    private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
        return models.stream().collect(
                Collectors.toMap(SchemaElement::getName, a -> {
                    Set<SchemaElement> result = new HashSet<>();
                    result.add(a);
                    return result;
                }, (k1, k2) -> {
                    k1.addAll(k2);
                    return k1;
                }));
    }

    private void addToSchemaMapInfo(Map<String, Set<SchemaElement>> mapResultRowSet, SchemaMapInfo schemaMap,
            SchemaElementType schemaElementType) {
        if (Objects.isNull(mapResultRowSet) || mapResultRowSet.size() <= 0) {
            return;
        }
        MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);

        for (Map.Entry<String, Set<SchemaElement>> entry : mapResultRowSet.entrySet()) {
            String detectWord = entry.getKey();
            Set<SchemaElement> schemaElements = entry.getValue();
            for (SchemaElement schemaElement : schemaElements) {

                List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
                if (CollectionUtils.isEmpty(elements)) {
                    elements = new ArrayList<>();
                    schemaMap.setMatchedElements(schemaElement.getModel(), elements);
                }
                Set<Long> regElementSet = elements.stream()
                        .filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))
                        .map(elementMatch -> elementMatch.getElement().getId())
                        .collect(Collectors.toSet());

                if (regElementSet.contains(schemaElement.getId())) {
                    continue;
                }
                SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                        .element(schemaElement)
                        .word(schemaElement.getName())
                        .detectWord(detectWord)
                        .frequency(10000L)
                        .similarity(mapperHelper.getSimilarity(detectWord, schemaElement.getName()))
                        .build();
                log.info("schemaElementType:{},add to schema, elementMatch {}", schemaElementType, schemaElementMatch);
                elements.add(schemaElementMatch);
            }
    private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
        List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
        if (CollectionUtils.isEmpty(elements)) {
            return new HashSet<>();
        }
        return elements.stream()
                .filter(elementMatch ->
                        SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
                                || SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
                .map(elementMatch -> elementMatch.getElement().getId())
                .collect(Collectors.toSet());
    }

}
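For illustration (not part of this diff): the getThreshold fallback above halves the configured similarity threshold when no schema element has matched yet, but never lets it drop below the configured minimum. A minimal sketch with hypothetical values:

public class ThresholdSketch {

    static double resolveThreshold(double configured, double minimum, boolean anyElementMatched) {
        if (anyElementMatched) {
            return configured;
        }
        double half = configured / 2;
        // fall back to the halved value unless it undercuts the floor
        return Math.max(half, minimum);
    }

    public static void main(String[] args) {
        System.out.println(resolveThreshold(0.6, 0.4, false)); // 0.4 (floor wins)
        System.out.println(resolveThreshold(0.9, 0.4, false)); // 0.45
        System.out.println(resolveThreshold(0.9, 0.4, true));  // 0.9 (unchanged)
    }
}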
@@ -0,0 +1,131 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

/**
 * Fuzzy Name Match Strategy
 */
@Service
@Slf4j
public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {

    @Autowired
    private OptimizationConfig optimizationConfig;
    @Autowired
    private MapperHelper mapperHelper;
    @Autowired
    private SchemaService schemaService;
    private List<SchemaElement> allElements;

    @Override
    public Map<MatchText, List<FuzzyResult>> match(QueryContext queryContext, List<Term> terms,
            Set<Long> detectModelIds) {
        this.allElements = getSchemaElements();
        return super.match(queryContext, terms, detectModelIds);
    }

    @Override
    public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) {
        return getMapKey(oneRoundResult).equals(getMapKey(existResult))
                && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
    }

    @Override
    public String getMapKey(FuzzyResult a) {
        return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
                + Constants.UNDERLINE + a.getSchemaElement().getName();
    }

    public void detectByStep(QueryContext queryContext, Set<FuzzyResult> existResults, Set<Long> detectModelIds,
            Integer startIndex, Integer index, int offset) {
        String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
        if (StringUtils.isBlank(detectSegment)) {
            return;
        }
        Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());

        Double metricDimensionThresholdConfig = getThreshold(queryContext);

        Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);

        for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
            String name = entry.getKey();
            if (!name.contains(detectSegment)
                    || mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
                continue;
            }
            Set<SchemaElement> schemaElements = entry.getValue();
            if (!CollectionUtils.isEmpty(modelIds)) {
                schemaElements = schemaElements.stream()
                        .filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
                        .collect(Collectors.toSet());
            }
            for (SchemaElement schemaElement : schemaElements) {
                FuzzyResult fuzzyResult = new FuzzyResult();
                fuzzyResult.setDetectWord(detectSegment);
                fuzzyResult.setName(schemaElement.getName());
                fuzzyResult.setSchemaElement(schemaElement);
                existResults.add(fuzzyResult);
            }
        }
    }

    private List<SchemaElement> getSchemaElements() {
        List<SchemaElement> allElements = new ArrayList<>();
        allElements.addAll(schemaService.getSemanticSchema().getDimensions());
        allElements.addAll(schemaService.getSemanticSchema().getMetrics());
        return allElements;
    }

    private Double getThreshold(QueryContext queryContext) {
        Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
        Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();

        Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();

        boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);

        if (!existElement) {
            double halfThreshold = metricDimensionThresholdConfig / 2;

            metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
                    : metricDimensionMinThresholdConfig;
            log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
                    modelElementMatches, metricDimensionThresholdConfig);
        }
        return metricDimensionThresholdConfig;
    }

    private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
        return models.stream().collect(
                Collectors.toMap(SchemaElement::getName, a -> {
                    Set<SchemaElement> result = new HashSet<>();
                    result.add(a);
                    return result;
                }, (k1, k2) -> {
                    k1.addAll(k2);
                    return k1;
                }));
    }
}
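For illustration (not part of this diff): the getNameToItems merge logic above groups elements that share a name into one Set via the Collectors.toMap merge function. A minimal sketch where the Item record is a hypothetical stand-in for SchemaElement:

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class NameToItemsSketch {

    record Item(String name, long model) { }

    public static void main(String[] args) {
        List<Item> items = List.of(new Item("pv", 1L), new Item("pv", 2L), new Item("uv", 1L));

        Map<String, Set<Item>> nameToItems = items.stream().collect(
                Collectors.toMap(Item::name, item -> {
                    Set<Item> set = new HashSet<>();
                    set.add(item);
                    return set;
                }, (left, right) -> {       // duplicate names: merge both sets
                    left.addAll(right);
                    return left;
                }));

        // "pv" maps to both items, "uv" maps to one
        System.out.println(nameToItems);
    }
}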
@@ -1,65 +1,48 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;

/***
 * A mapper capable of prefix and suffix similarity parsing for
 * domain names, dimension values, metric names, and dimension names.
 */
@Slf4j
public class HanlpDictMapper implements SchemaMapper {
public class HanlpDictMapper extends BaseMapper {

    @Override
    public void map(QueryContext queryContext) {
    public void doMap(QueryContext queryContext) {

        String queryText = queryContext.getRequest().getQueryText();
        List<Term> terms = HanlpHelper.getTerms(queryText);

        for (Term term : terms) {
            log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
        }
        HanlpDictMatchStrategy matchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);

        QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
        MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
        Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());

        Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
                detectModelIds);

        List<MapResult> matches = getMatches(matchResult);
        List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext, terms);

        HanlpHelper.transLetterOriginal(matches);

        log.info("queryContext:{},matches:{}", queryContext, matches);

        convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
    }

    private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
        if (CollectionUtils.isEmpty(mapResults)) {
    private void convertTermsToSchemaMapInfo(List<HanlpMapResult> hanlpMapResults, SchemaMapInfo schemaMap,
            List<Term> terms) {
        if (CollectionUtils.isEmpty(hanlpMapResults)) {
            return;
        }

@@ -67,8 +50,8 @@ public class HanlpDictMapper implements SchemaMapper {
                Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
                        term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));

        for (MapResult mapResult : mapResults) {
            for (String nature : mapResult.getNatures()) {
        for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
            for (String nature : hanlpMapResult.getNatures()) {
                Long modelId = NatureHelper.getModelId(nature);
                if (Objects.isNull(modelId)) {
                    continue;
@@ -77,68 +60,27 @@ public class HanlpDictMapper implements SchemaMapper {
                if (Objects.isNull(elementType)) {
                    continue;
                }

                SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
                ModelSchema modelSchema = schemaService.getModelSchema(modelId);

                BaseWordBuilder baseWordBuilder = WordBuilderFactory.get(DictWordType.getNatureType(nature));
                Long elementID = baseWordBuilder.getElementID(nature);
                Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);

                SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
                if (Objects.isNull(elementDb)) {
                    log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
                Long elementID = NatureHelper.getElementID(nature);
                SchemaElement element = getSchemaElement(modelId, elementType, elementID);
                if (element == null) {
                    continue;
                }
                SchemaElement element = new SchemaElement();
                BeanUtils.copyProperties(elementDb, element);
                element.setAlias(getAlias(elementDb));
                if (element.getType().equals(SchemaElementType.VALUE)) {
                    element.setName(mapResult.getName());
                    element.setName(hanlpMapResult.getName());
                }
                Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
                SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
                        .element(element)
                        .frequency(frequency)
                        .word(mapResult.getName())
                        .similarity(mapResult.getSimilarity())
                        .detectWord(mapResult.getDetectWord())
                        .word(hanlpMapResult.getName())
                        .similarity(hanlpMapResult.getSimilarity())
                        .detectWord(hanlpMapResult.getDetectWord())
                        .build();

                Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
                List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId,
                        new ArrayList<>());
                if (schemaElementMatches == null) {
                    schemaElementMatches = modelElementMatches.get(modelId);
                }
                schemaElementMatches.add(schemaElementMatch);
                addToSchemaMap(schemaMap, modelId, schemaElementMatch);
            }
        }
    }

    private List<MapResult> getMatches(Map<MatchText, List<MapResult>> matchResult) {
        List<MapResult> matches = new ArrayList<>();
        if (Objects.isNull(matchResult)) {
            return matches;
        }
        Optional<List<MapResult>> first = matchResult.entrySet().stream()
                .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
                .map(entry -> entry.getValue()).findFirst();

        if (first.isPresent()) {
            matches = first.get();
        }
        return matches;
    }

    public List<String> getAlias(SchemaElement element) {
        if (!SchemaElementType.VALUE.equals(element.getType())) {
            return element.getAlias();
        }
        if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
            return element.getAlias().stream()
                    .filter(aliasItem -> aliasItem.contains(element.getName()))
                    .collect(Collectors.toList());
        }
        return element.getAlias();
    }
}
@@ -0,0 +1,123 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * match strategy implement
 */
@Service
@Slf4j
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {

    @Autowired
    private MapperHelper mapperHelper;

    @Autowired
    private OptimizationConfig optimizationConfig;

    @Override
    public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
            Set<Long> detectModelIds) {
        QueryReq queryReq = queryContext.getRequest();
        String text = queryReq.getQueryText();
        if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
            return null;
        }

        log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);

        List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
        Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();

        result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
        return result;
    }

    @Override
    public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
        return getMapKey(oneRoundResult).equals(getMapKey(existResult))
                && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
    }

    public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
            Integer startIndex, Integer index, int offset) {
        QueryReq queryReq = queryContext.getRequest();
        String text = queryReq.getQueryText();
        Integer agentId = queryReq.getAgentId();
        String detectSegment = text.substring(startIndex, index);

        // step1. pre search
        Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
        LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
                agentId,
                detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
        // step2. suffix search
        LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
                oneDetectionMaxSize, agentId, detectModelIds).stream()
                .collect(Collectors.toCollection(LinkedHashSet::new));

        hanlpMapResults.addAll(suffixHanlpMapResults);

        if (CollectionUtils.isEmpty(hanlpMapResults)) {
            return;
        }
        // step3. merge pre/suffix result
        hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
                .collect(Collectors.toCollection(LinkedHashSet::new));

        // step4. filter by similarity
        hanlpMapResults = hanlpMapResults.stream()
                .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
                        >= mapperHelper.getThresholdMatch(term.getNatures()))
                .filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
                .collect(Collectors.toCollection(LinkedHashSet::new));

        log.info("after isSimilarity parseResults:{}", hanlpMapResults);

        hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
            parseResult.setOffset(offset);
            parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
            return parseResult;
        }).collect(Collectors.toCollection(LinkedHashSet::new));

        // step5. take only one dimension or 10 metric/dimension value per rond.
        List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
                .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
                .collect(Collectors.toList())
                .stream()
                .limit(1)
                .collect(Collectors.toList());

        Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
        List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
                .collect(Collectors.toList());

        if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
            oneRoundResults = dimensionMetrics;
        }
        // step6. select mapResul in one round
        selectResultInOneRound(existResults, oneRoundResults);
    }

    public String getMapKey(HanlpMapResult a) {
        return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
    }
}
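For illustration (not part of this diff): the per-round de-duplication contract expressed by needDelete above keeps, for candidates sharing a map key, the one detected from the longer segment. A minimal sketch where Candidate is a hypothetical stand-in for HanlpMapResult:

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SelectInOneRoundSketch {

    record Candidate(String key, String detectWord) { }

    public static void main(String[] args) {
        List<Candidate> oneRound = List.of(
                new Candidate("song_name", "周杰"),
                new Candidate("song_name", "周杰伦"),
                new Candidate("metric_pv", "pv"));

        Map<String, Candidate> selected = new HashMap<>();
        for (Candidate candidate : oneRound) {
            // per key, prefer the candidate whose detect word covers more of the query
            selected.merge(candidate.key(), candidate, (existing, incoming) ->
                    incoming.detectWord().length() > existing.detectWord().length() ? incoming : existing);
        }
        System.out.println(selected.values()); // keeps 周杰伦 for song_name, pv for metric_pv
    }
}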
@@ -1,11 +1,13 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.algorithm.EditDistance;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -39,10 +41,14 @@ public class MapperHelper {
        return index;
    }

    public Integer getStepOffset(List<Integer> termList, Integer index) {

    public Integer getStepOffset(List<Term> termList, Integer index) {
        List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
                .map(term -> term.getOffset()).collect(Collectors.toList());

        for (int j = 0; j < termList.size() - 1; j++) {
            if (termList.get(j) <= index && termList.get(j + 1) > index) {
                return termList.get(j);
            if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
                return offsetList.get(j);
            }
        }
        return index;
@@ -88,7 +94,7 @@ public class MapperHelper {

        AgentService agentService = ContextUtils.getBean(AgentService.class);

        Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
        Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
        //contains all
        if (agentService.containsAllModel(detectModelIds)) {
            if (Objects.nonNull(modelId) && modelId > 0) {
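For illustration (not part of this diff): the reworked getStepOffset above walks the sorted term offsets and returns the offset of the term that brackets the probe index, falling back to the index itself. A minimal sketch with plain integers standing in for HanLP term offsets:

import java.util.List;

public class StepOffsetSketch {

    static int stepOffset(List<Integer> sortedOffsets, int index) {
        for (int j = 0; j < sortedOffsets.size() - 1; j++) {
            if (sortedOffsets.get(j) <= index && sortedOffsets.get(j + 1) > index) {
                return sortedOffsets.get(j);
            }
        }
        return index;
    }

    public static void main(String[] args) {
        List<Integer> offsets = List.of(0, 2, 5, 9);
        System.out.println(stepOffset(offsets, 3)); // 2: index 3 falls inside the term starting at 2
        System.out.println(stepOffset(offsets, 9)); // 9: beyond the last bracket, the index is returned as-is
    }
}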
@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -10,8 +9,8 @@ import java.util.Set;
/**
 * match strategy
 */
public interface MatchStrategy {
public interface MatchStrategy<T> {

    Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
    Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);

}
@@ -0,0 +1,52 @@
package com.tencent.supersonic.chat.mapper;

import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class ModelClusterMapper implements SchemaMapper {

    @Override
    public void map(QueryContext queryContext) {
        SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
        SemanticSchema semanticSchema = schemaService.getSemanticSchema();
        SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
        List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
        Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
        for (ModelCluster modelCluster : modelClusters) {
            for (Long modelId : schemaMapInfo.getMatchedModels()) {
                if (modelCluster.getModelIds().contains(modelId)) {
                    modelClusterElementMatches.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
                            .addAll(schemaMapInfo.getMatchedElements(modelId));
                }
            }
        }
        SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
        modelClusterMapInfo.setModelElementMatches(modelClusterElementMatches);
        queryContext.setModelClusterMapInfo(modelClusterMapInfo);
    }

    private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
            SemanticSchema semanticSchema) {
        Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
        List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
        return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
            modelCluster.removeIf(model -> !matchedModels.contains(model));
        }).filter(modelCluster -> modelCluster.size() > 0).map(ModelCluster::build).collect(Collectors.toList());
    }

}
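For illustration (not part of this diff): map() above regroups per-model matches under the key of the cluster that contains the model, using computeIfAbsent to create each cluster's list on first use. A minimal sketch with strings and numbers standing in for the supersonic types:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class ClusterRegroupSketch {

    public static void main(String[] args) {
        Map<Long, List<String>> matchesByModel = Map.of(
                1L, List.of("pv", "uv"),
                2L, List.of("song_name"),
                3L, List.of("album"));
        // one cluster spanning models 1 and 2, another holding model 3 (hypothetical keys)
        Map<String, Set<Long>> clusters = Map.of("cluster_1_2", Set.of(1L, 2L), "cluster_3", Set.of(3L));

        Map<String, List<String>> matchesByCluster = new HashMap<>();
        clusters.forEach((clusterKey, modelIds) ->
                matchesByModel.forEach((modelId, matches) -> {
                    if (modelIds.contains(modelId)) {
                        matchesByCluster.computeIfAbsent(clusterKey, k -> new ArrayList<>()).addAll(matches);
                    }
                }));
        System.out.println(matchesByCluster);
    }
}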
@@ -1,163 +0,0 @@
package com.tencent.supersonic.chat.mapper;

import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * match strategy implement
 */
@Service
@Slf4j
public class QueryMatchStrategy implements MatchStrategy {

    @Autowired
    private MapperHelper mapperHelper;

    @Autowired
    private OptimizationConfig optimizationConfig;

    @Override
    public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
        String text = queryReq.getQueryText();
        if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
            return null;
        }

        Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
                .collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));

        List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
                .map(term -> term.getOffset()).collect(Collectors.toList());

        log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
                regOffsetToLength, offsetList, detectModelIds);

        List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
        Map<MatchText, List<MapResult>> result = new HashMap<>();

        result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
        return result;
    }

    private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
            Set<Long> detectModelIds) {
        String text = queryReq.getQueryText();
        List<MapResult> results = Lists.newArrayList();

        for (Integer index = 0; index <= text.length() - 1; ) {

            Set<MapResult> mapResultRowSet = new LinkedHashSet();

            for (Integer i = index; i <= text.length(); ) {
                int offset = mapperHelper.getStepOffset(offsetList, index);
                i = mapperHelper.getStepIndex(regOffsetToLength, i);
                if (i <= text.length()) {
                    List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
                    selectMapResultInOneRound(mapResultRowSet, mapResults);
                }
            }
            index = mapperHelper.getStepIndex(regOffsetToLength, index);
            results.addAll(mapResultRowSet);
        }
        return results;
    }

    private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
        for (MapResult mapResult : mapResults) {
            if (mapResultRowSet.contains(mapResult)) {
                boolean isDeleted = mapResultRowSet.removeIf(
                        entry -> {
                            boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
                                    && entry.getDetectWord().length() < mapResult.getDetectWord().length();
                            if (deleted) {
                                log.info("deleted entry:{}", entry);
                            }
                            return deleted;
                        }
                );
                if (isDeleted) {
                    log.info("deleted, add mapResult:{}", mapResult);
                    mapResultRowSet.add(mapResult);
                }
            } else {
                mapResultRowSet.add(mapResult);
            }
        }
    }

    private String getMapKey(MapResult a) {
        return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
    }

    private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
            int offset) {
        String text = queryReq.getQueryText();
        Integer agentId = queryReq.getAgentId();
        String detectSegment = text.substring(index, i);

        // step1. pre search
        Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
        LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
                detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
        // step2. suffix search
        LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
                agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));

        mapResults.addAll(suffixMapResults);

        if (CollectionUtils.isEmpty(mapResults)) {
            return new ArrayList<>();
        }
        // step3. merge pre/suffix result
        mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
                .collect(Collectors.toCollection(LinkedHashSet::new));

        // step4. filter by similarity
        mapResults = mapResults.stream()
                .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
                        >= mapperHelper.getThresholdMatch(term.getNatures()))
                .filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
                .collect(Collectors.toCollection(LinkedHashSet::new));

        log.info("after isSimilarity parseResults:{}", mapResults);

        mapResults = mapResults.stream().map(parseResult -> {
            parseResult.setOffset(offset);
            parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
            return parseResult;
        }).collect(Collectors.toCollection(LinkedHashSet::new));

        // step5. take only one dimension or 10 metric/dimension value per rond.
        List<MapResult> dimensionMetrics = mapResults.stream()
                .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
                .collect(Collectors.toList())
                .stream()
                .limit(1)
                .collect(Collectors.toList());

        if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
            return dimensionMetrics;
        } else {
            return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
        }
    }
}
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.mapper;

import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;
import java.util.Map;
@@ -20,17 +21,16 @@ import org.springframework.stereotype.Service;
 * match strategy implement
 */
@Service
public class SearchMatchStrategy implements MatchStrategy {
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {

    private static final int SEARCH_SIZE = 3;

    @Override
    public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
    public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
            Set<Long> detectModelIds) {
        QueryReq queryReq = queryContext.getRequest();
        String text = queryReq.getQueryText();
        Map<Integer, Integer> regOffsetToLength = originals.stream()
                .filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
                .collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
                        (value1, value2) -> value2));
        Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);

        List<Integer> detectIndexList = Lists.newArrayList();

@@ -46,19 +46,19 @@ public class SearchMatchStrategy implements MatchStrategy {
                index++;
            }
        }
        Map<MatchText, List<MapResult>> regTextMap = new ConcurrentHashMap<>();
        Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
        detectIndexList.stream().parallel().forEach(detectIndex -> {
            String regText = text.substring(0, detectIndex);
            String detectSegment = text.substring(detectIndex);

            if (StringUtils.isNotEmpty(detectSegment)) {
                List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
                List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
                        SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
                List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
                        queryReq.getAgentId(), detectModelIds);
                mapResults.addAll(suffixMapResults);
                List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
                        detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
                hanlpMapResults.addAll(suffixHanlpMapResults);
                // remove entity name where search
                mapResults = mapResults.stream().filter(entry -> {
                hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
                    List<String> natures = entry.getNatures().stream()
                            .filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
                            .collect(Collectors.toList());
@@ -71,10 +71,27 @@ public class SearchMatchStrategy implements MatchStrategy {
                        .regText(regText)
                        .detectSegment(detectSegment)
                        .build();
                regTextMap.put(matchText, mapResults);
                regTextMap.put(matchText, hanlpMapResults);
            }
        }
        );
        return regTextMap;
    }

    @Override
    public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
        return false;
    }

    @Override
    public String getMapKey(HanlpMapResult a) {
        return null;
    }

    @Override
    public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
            Integer startIndex,
            Integer i, int offset) {

    }
}
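For illustration (not part of this diff): the match() method above fans out over candidate split points with a parallel stream and collects results into a ConcurrentHashMap, which tolerates concurrent writes. A minimal sketch where the lookup itself is faked:

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class ParallelDetectSketch {

    public static void main(String[] args) {
        String text = "play jay chou songs";
        List<Integer> detectIndexes = List.of(0, 5, 9);

        Map<String, List<String>> resultsBySegment = new ConcurrentHashMap<>();
        detectIndexes.parallelStream().forEach(detectIndex -> {
            String detectSegment = text.substring(detectIndex);
            resultsBySegment.put(detectSegment, fakeSearch(detectSegment));
        });
        System.out.println(resultsBySegment.keySet());
    }

    // stand-in for the prefix/suffix dictionary search
    private static List<String> fakeSearch(String segment) {
        return List.of(segment.toUpperCase());
    }
}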
@@ -0,0 +1,71 @@
package com.tencent.supersonic.chat.parser;

import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.net.URL;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

@Slf4j
public class HttpLLMInterpreter implements LLMInterpreter {

    public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {

        long startTime = System.currentTimeMillis();
        log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
        try {
            LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);

            URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
            RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
            ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
                    LLMResp.class);

            log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
                    System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
            return responseEntity.getBody();
        } catch (Exception e) {
            log.error("requestLLM error", e);
        }
        return null;
    }

    public FunctionResp requestFunction(FunctionReq functionReq) {
        FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
        String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
        HttpHeaders headers = new HttpHeaders();
        long startTime = System.currentTimeMillis();
        headers.setContentType(MediaType.APPLICATION_JSON);
        HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
        URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
        RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
        try {
            log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
            ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
                    FunctionResp.class);
            log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
                    System.currentTimeMillis() - startTime);
            return responseEntity.getBody();
        } catch (Exception e) {
            log.error("requestFunction error", e);
        }
        return null;
    }
}
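For illustration (not part of this diff): a self-contained sketch of the RestTemplate call pattern used by query2sql above, posting a JSON body and reading a typed response; the URL and the Reply type are hypothetical, and the error handling mirrors the log-and-return-null style of the interpreter.

import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

public class LlmHttpCallSketch {

    static class Reply {
        public String sql;
    }

    public static Reply post(String url, String jsonBody) {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
        try {
            ResponseEntity<Reply> response = new RestTemplate()
                    .exchange(url, HttpMethod.POST, entity, Reply.class);
            return response.getBody();
        } catch (Exception e) {
            // matches the interpreter's behaviour: log and fall back to null
            e.printStackTrace();
            return null;
        }
    }
}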
@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.parser;

import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;

/**
 * Unified interpreter for invoking the llm layer.
 */
public interface LLMInterpreter {

    LLMResp query2sql(LLMReq llmReq, String modelClusterKey);

    FunctionResp requestFunction(FunctionReq functionReq);

}
@@ -0,0 +1,85 @@
package com.tencent.supersonic.chat.parser;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Query type parser that determines whether the query is a metric query,
 * an entity query, or another type of query.
 */
@Slf4j
public class QueryTypeParser implements SemanticParser {

    @Override
    public void parse(QueryContext queryContext, ChatContext chatContext) {

        List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
        User user = queryContext.getRequest().getUser();

        for (SemanticQuery semanticQuery : candidateQueries) {
            // 1. init S2SQL
            semanticQuery.initS2Sql(user);
            // 2. set queryType
            QueryType queryType = getQueryType(semanticQuery);
            semanticQuery.getParseInfo().setQueryType(queryType);
        }
    }

    private QueryType getQueryType(SemanticQuery semanticQuery) {
        SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
        SqlInfo sqlInfo = parseInfo.getSqlInfo();
        if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
            return QueryType.OTHER;
        }
        // 1. entity queryType
        Set<Long> modelIds = parseInfo.getModel().getModelIds();
        if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
            // If all the fields in the SELECT statement are of tag type.
            List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
            SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
            SemanticSchema semanticSchema = semanticService.getSemanticSchema();
            if (CollectionUtils.isNotEmpty(selectFields)) {
                Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
                        .collect(Collectors.toSet());
                if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
                    return QueryType.TAG;
                }
            }
        }
        // 2. metric queryType
        List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
        if (CollectionUtils.isNotEmpty(metrics)) {
            Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
            boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
            if (containMetric) {
                return QueryType.METRIC;
            }
        }
        return QueryType.OTHER;
    }

}
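The decision order matters here: a SELECT list made up entirely of tag-typed fields is classified as a TAG (entity) query before the metric check runs, and any metric appearing in the SELECT list makes it a METRIC query; everything else falls back to OTHER. A reduced, self-contained sketch of that rule follows; the field names are made up for illustration, and the real parser additionally restricts the tag branch to RuleSemanticQuery/S2SQLQuery candidates and reads the schema from SemanticSchema.

// Sketch only: the classification rule of getQueryType with the schema reduced to string sets.
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class QueryTypeRuleSketch {

    enum QueryType { TAG, METRIC, OTHER }

    static QueryType classify(List<String> selectFields, Set<String> tags, Set<String> metrics) {
        // Tag query: every selected field is a tag-typed element.
        if (!selectFields.isEmpty() && !tags.isEmpty() && tags.containsAll(selectFields)) {
            return QueryType.TAG;
        }
        // Metric query: at least one selected field is a metric.
        if (selectFields.stream().anyMatch(metrics::contains)) {
            return QueryType.METRIC;
        }
        return QueryType.OTHER;
    }

    public static void main(String[] args) {
        Set<String> tags = new HashSet<>(Arrays.asList("singer_name", "song_name"));
        Set<String> metrics = new HashSet<>(Arrays.asList("play_count"));
        System.out.println(classify(Arrays.asList("singer_name"), tags, metrics));               // TAG
        System.out.println(classify(Arrays.asList("singer_name", "play_count"), tags, metrics)); // METRIC
        System.out.println(classify(Arrays.asList("album_name"), tags, metrics));                // OTHER
    }
}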
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.OptimizationConfig;
-import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
+import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;

@@ -20,7 +20,7 @@ public class SatisfactionChecker {
    // check all the parse info in candidate
    public static boolean check(QueryContext queryContext) {
        for (SemanticQuery query : queryContext.getCandidateQueries()) {
-           if (query.getQueryMode().equals(DslQuery.QUERY_MODE)) {
+           if (query.getQueryMode().equals(S2SQLQuery.QUERY_MODE)) {
                continue;
            }
            if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
@@ -1,25 +0,0 @@
package com.tencent.supersonic.chat.parser.llm.dsl;

import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class DSLParseResult {

    private LLMReq llmReq;

    private LLMResp llmResp;

    private QueryReq request;

    private DslTool dslTool;
}
@@ -1,435 +0,0 @@
package com.tencent.supersonic.chat.parser.llm.dsl;

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.corrector.BaseSemanticCorrector;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.plugin.function.ModelResolver;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;

@Slf4j
public class LLMDslParser implements SemanticParser {

    public static final double function_bonus_threshold = 201;
    @Override
    public void parse(QueryContext queryCtx, ChatContext chatCtx) {
        QueryReq request = queryCtx.getRequest();
        LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
        if (StringUtils.isEmpty(llmConfig.getUrl())) {
            log.info("llm url is empty, skip dsl parser, llmConfig:{}", llmConfig);
            return;
        }
        if (SatisfactionChecker.check(queryCtx)) {
            log.info("skip dsl parser, queryText:{}", request.getQueryText());
            return;
        }
        try {
            Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
            if (Objects.isNull(modelId) || modelId <= 0) {
                return;
            }

            DslTool dslTool = getDslTool(request, modelId);
            if (Objects.isNull(dslTool)) {
                log.info("no dsl tool in this agent, skip dsl parser");
                return;
            }

            LLMReq llmReq = getLlmReq(queryCtx, modelId);
            LLMResp llmResp = requestLLM(llmReq, modelId, llmConfig);

            if (Objects.isNull(llmResp)) {
                return;
            }
            DSLParseResult dslParseResult = DSLParseResult.builder().request(request)
                    .dslTool(dslTool).llmReq(llmReq).llmResp(llmResp).build();

            SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);

            SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());

            llmResp.setCorrectorSql(semanticCorrectInfo.getSql());

            updateParseInfo(semanticCorrectInfo, modelId, parseInfo);

        } catch (Exception e) {
            log.error("LLMDSLParser error", e);
        }
    }
    private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
        return elements.stream()
                .filter(schemaElement -> modelId.equals(schemaElement.getModel())
                        && allFields.contains(schemaElement.getBizName())
                ).collect(Collectors.toSet());
    }

    private List<String> getFieldsExceptDate(List<String> allFields) {
        if (CollectionUtils.isEmpty(allFields)) {
            return new ArrayList<>();
        }
        return allFields.stream()
                .filter(entry -> !TimeDimensionEnum.getNameList().contains(entry))
                .collect(Collectors.toList());
    }

    public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {

        String correctorSql = semanticCorrectInfo.getPreSql();
        if (StringUtils.isEmpty(correctorSql)) {
            correctorSql = semanticCorrectInfo.getSql();
        }
        List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
        // set dateInfo
        try {
            if (!CollectionUtils.isEmpty(expressions)) {
                DateConf dateInfo = getDateInfo(expressions);
                parseInfo.setDateInfo(dateInfo);
            }
        } catch (Exception e) {
            log.error("set dateInfo error :", e);
        }

        // set filter
        try {
            Map<String, SchemaElement> bizNameToElement = getBizNameToElement(modelId);
            List<QueryFilter> result = getDimensionFilter(bizNameToElement, expressions);
            parseInfo.getDimensionFilters().addAll(result);
        } catch (Exception e) {
            log.error("set dimensionFilter error :", e);
        }

        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

        if (Objects.isNull(semanticSchema)) {
            return;
        }
        List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql()));

        Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
        parseInfo.setMetrics(metrics);

        if (SqlParserSelectHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) {
            parseInfo.setNativeQuery(false);
            List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
            List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
            parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions()));
        } else {
            parseInfo.setNativeQuery(true);
            List<String> selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql());
            List<String> selectDimensions = getFieldsExceptDate(selectFields);
            parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions()));
        }
    }
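Aside (not part of the original file): updateParseInfo above treats SQL containing an aggregate function as an aggregated query, so nativeQuery is false and the dimensions come from GROUP BY; otherwise it is a native detail query and the dimensions come from the SELECT list. A simplified, self-contained illustration of that branching rule follows; the keyword-based aggregate check and the sample field names are assumptions standing in for SqlParserSelectHelper.

// Sketch only: aggregated SQL -> dimensions from GROUP BY; detail SQL -> dimensions from SELECT.
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

public class NativeQueryRuleSketch {

    private static final Pattern AGG = Pattern.compile("\\b(sum|count|avg|max|min)\\s*\\(", Pattern.CASE_INSENSITIVE);

    static String describe(String sql, List<String> selectFields, List<String> groupByFields) {
        boolean aggregated = AGG.matcher(sql).find();
        List<String> dimensionSource = aggregated ? groupByFields : selectFields;
        return (aggregated ? "nativeQuery=false, dimensions from GROUP BY: "
                           : "nativeQuery=true, dimensions from SELECT: ") + dimensionSource;
    }

    public static void main(String[] args) {
        System.out.println(describe("SELECT singer_name, SUM(play_count) FROM songs GROUP BY singer_name",
                Arrays.asList("singer_name"), Arrays.asList("singer_name")));
        System.out.println(describe("SELECT song_name, publish_date FROM songs",
                Arrays.asList("song_name", "publish_date"), Arrays.asList()));
    }
}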
    private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> bizNameToElement,
            List<FilterExpression> filterExpressions) {
        List<QueryFilter> result = Lists.newArrayList();
        for (FilterExpression expression : filterExpressions) {
            QueryFilter dimensionFilter = new QueryFilter();
            dimensionFilter.setValue(expression.getFieldValue());
            String bizName = expression.getFieldName();
            SchemaElement schemaElement = bizNameToElement.get(bizName);
            if (Objects.isNull(schemaElement)) {
                continue;
            }
            String fieldName = schemaElement.getName();
            dimensionFilter.setName(fieldName);
            dimensionFilter.setBizName(bizName);
            dimensionFilter.setElementID(schemaElement.getId());

            FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
            dimensionFilter.setOperator(operatorEnum);
            result.add(dimensionFilter);
        }
        return result;
    }

    private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
        List<FilterExpression> dateExpressions = filterExpressions.stream()
                .filter(expression -> {
                    List<String> nameList = TimeDimensionEnum.getNameList();
                    if (StringUtils.isEmpty(expression.getFieldName())) {
                        return false;
                    }
                    return nameList.contains(expression.getFieldName().toLowerCase());
                }).collect(Collectors.toList());
        if (CollectionUtils.isEmpty(dateExpressions)) {
            return new DateConf();
        }
        DateConf dateInfo = new DateConf();
        dateInfo.setDateMode(DateMode.BETWEEN);
        FilterExpression firstExpression = dateExpressions.get(0);

        FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
        if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
            dateInfo.setStartDate(firstExpression.getFieldValue().toString());
            dateInfo.setEndDate(firstExpression.getFieldValue().toString());
            dateInfo.setDateMode(DateMode.BETWEEN);
            return dateInfo;
        }
        if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
                FilterOperatorEnum.GREATER_THAN_EQUALS)) {
            dateInfo.setStartDate(firstExpression.getFieldValue().toString());
            if (hasSecondDate(dateExpressions)) {
                dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
            }
        }
        if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
                FilterOperatorEnum.MINOR_THAN_EQUALS)) {
            dateInfo.setEndDate(firstExpression.getFieldValue().toString());
            if (hasSecondDate(dateExpressions)) {
                dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
            }
        }
        return dateInfo;
    }
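Aside (not part of the original file): getDateInfo above collapses the date predicates found in the corrected SQL into a single DateConf window. An equality pins start and end to the same day; a greater-than bound supplies the start, with an optional second expression providing the end; a less-than bound does the reverse. A reduced, self-contained sketch of that rule follows; the string-typed operators and values are a simplification for illustration only.

// Sketch only: mapping the first date comparison (plus an optional second bound) to a window.
import java.util.Optional;

public class DateWindowSketch {

    static class Window {
        String start;
        String end;
        public String toString() { return "[" + start + ", " + end + "]"; }
    }

    // first: the first date comparison found in the WHERE clause; secondValue: an optional
    // second bound, e.g. a >= '2023-08-01' paired with a <= '2023-08-07' on the date field.
    static Window toWindow(String firstOp, String firstValue, Optional<String> secondValue) {
        Window window = new Window();
        if ("=".equals(firstOp)) {                 // exact day: start == end
            window.start = firstValue;
            window.end = firstValue;
        } else if (firstOp.startsWith(">")) {      // lower bound first, optional upper bound second
            window.start = firstValue;
            secondValue.ifPresent(v -> window.end = v);
        } else if (firstOp.startsWith("<")) {      // upper bound first, optional lower bound second
            window.end = firstValue;
            secondValue.ifPresent(v -> window.start = v);
        }
        return window;
    }

    public static void main(String[] args) {
        System.out.println(toWindow("=", "2023-08-01", Optional.empty()));
        System.out.println(toWindow(">=", "2023-08-01", Optional.of("2023-08-07")));
    }
}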
    private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
            FilterOperatorEnum... operatorEnums) {
        return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
    }

    private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
        return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
    }

    private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {

        SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
                .queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
                .parseInfo(parseInfo).build();

        List<SemanticCorrector> dslCorrections = ComponentFactory.getSqlCorrections();

        dslCorrections.forEach(dslCorrection -> {
            try {
                dslCorrection.correct(correctInfo);
                log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(), correctInfo.getSql());
            } catch (Exception e) {
                log.error("sqlCorrection:{} correct error,correctInfo:{}", dslCorrection, correctInfo, e);
            }
        });
        return correctInfo;
    }

    private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
            DSLParseResult dslParseResult) {
        PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DslQuery.QUERY_MODE);
        SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
        parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));

        Map<String, Object> properties = new HashMap<>();
        properties.put(Constants.CONTEXT, dslParseResult);
        properties.put("type", "internal");
        properties.put("name", dslTool.getName());

        parseInfo.setProperties(properties);
        parseInfo.setScore(function_bonus_threshold);
        parseInfo.setQueryMode(semanticQuery.getQueryMode());

        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();

        SchemaElement model = new SchemaElement();
        model.setModel(modelId);
        model.setId(modelId);
        model.setName(modelIdToName.get(modelId));
        parseInfo.setModel(model);
        queryCtx.getCandidateQueries().add(semanticQuery);
        return parseInfo;
    }
    private DslTool getDslTool(QueryReq request, Long modelId) {
        AgentService agentService = ContextUtils.getBean(AgentService.class);
        List<DslTool> dslTools = agentService.getDslTools(request.getAgentId(), AgentToolType.DSL);
        Optional<DslTool> dslToolOptional = dslTools.stream()
                .filter(tool -> {
                    List<Long> modelIds = tool.getModelIds();
                    if (agentService.containsAllModel(new HashSet<>(modelIds))) {
                        return true;
                    }
                    return modelIds.contains(modelId);
                })
                .findFirst();
        return dslToolOptional.orElse(null);
    }

    private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
        AgentService agentService = ContextUtils.getBean(AgentService.class);
        Set<Long> distinctModelIds = agentService.getDslToolsModelIds(agentId, AgentToolType.DSL);
        if (agentService.containsAllModel(distinctModelIds)) {
            distinctModelIds = new HashSet<>();
        }
        ModelResolver modelResolver = ComponentFactory.getModelResolver();
        Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
        log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
        return modelId;
    }

    private LLMResp requestLLM(LLMReq llmReq, Long modelId, LLMConfig llmConfig) {
        String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath();
        long startTime = System.currentTimeMillis();
        log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
        RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
        try {
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
            ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity,
                    LLMResp.class);

            log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
                    System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody());
            return responseEntity.getBody();
        } catch (Exception e) {
            log.error("requestLLM error", e);
        }
        return null;
    }

    private LLMReq getLlmReq(QueryContext queryCtx, Long modelId) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
        String queryText = queryCtx.getRequest().getQueryText();
        LLMReq llmReq = new LLMReq();
        llmReq.setQueryText(queryText);
        LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
        llmSchema.setModelName(modelIdToName.get(modelId));
        llmSchema.setDomainName(modelIdToName.get(modelId));
        List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
        fieldNameList.add(BaseSemanticCorrector.DATE_FIELD);
        llmSchema.setFieldNameList(fieldNameList);
        llmReq.setSchema(llmSchema);
        List<ElementValue> linking = new ArrayList<>();
        linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
        llmReq.setLinking(linking);
        String currentDate = DSLDateHelper.getReferenceDate(modelId);
        llmReq.setCurrentDate(currentDate);
        return llmReq;
    }
    protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
        Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);

        List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
        if (CollectionUtils.isEmpty(matchedElements)) {
            return new ArrayList<>();
        }
        Set<ElementValue> valueMatches = matchedElements
                .stream()
                .filter(elementMatch -> !elementMatch.isInherited())
                .filter(schemaElementMatch -> {
                    SchemaElementType type = schemaElementMatch.getElement().getType();
                    return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
                })
                .map(elementMatch -> {
                    ElementValue elementValue = new ElementValue();
                    elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
                    elementValue.setFieldValue(elementMatch.getWord());
                    return elementValue;
                }).collect(Collectors.toSet());
        return new ArrayList<>(valueMatches);
    }

    protected Map<String, SchemaElement> getBizNameToElement(Long modelId) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        List<SchemaElement> dimensions = semanticSchema.getDimensions();
        List<SchemaElement> metrics = semanticSchema.getMetrics();

        List<SchemaElement> allElements = Lists.newArrayList();
        allElements.addAll(dimensions);
        allElements.addAll(metrics);
        return allElements.stream()
                .filter(schemaElement -> schemaElement.getModel().equals(modelId))
                .collect(Collectors.toMap(SchemaElement::getBizName, Function.identity(), (value1, value2) -> value2));
    }

    protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
        Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);

        List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
        if (CollectionUtils.isEmpty(matchedElements)) {
            return new ArrayList<>();
        }
        Set<String> fieldNameList = matchedElements.stream()
                .filter(schemaElementMatch -> {
                    SchemaElementType elementType = schemaElementMatch.getElement().getType();
                    return SchemaElementType.METRIC.equals(elementType)
                            || SchemaElementType.DIMENSION.equals(elementType)
                            || SchemaElementType.VALUE.equals(elementType);
                })
                .map(schemaElementMatch -> {
                    SchemaElementType elementType = schemaElementMatch.getElement().getType();

                    if (!SchemaElementType.VALUE.equals(elementType)) {
                        return schemaElementMatch.getWord();
                    }
                    return itemIdToName.get(schemaElementMatch.getElement().getId());
                })
                .filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
                .collect(Collectors.toSet());
        return new ArrayList<>(fieldNameList);
    }

    protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
        return semanticSchema.getDimensions().stream()
                .filter(entry -> modelId.equals(entry.getModel()))
                .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
    }

}
Some files were not shown because too many files have changed in this diff.