mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
Compare commits
414 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f60c1675cd | ||
|
|
1d9b6d6877 | ||
|
|
d8930e8906 | ||
|
|
c68df24375 | ||
|
|
bb1001677d | ||
|
|
7a1cfbcef8 | ||
|
|
67b9c4bf79 | ||
|
|
7cb7697353 | ||
|
|
3e18655c69 | ||
|
|
e7d52f87f0 | ||
|
|
2cd8f8022b | ||
|
|
e08435902a | ||
|
|
b44fa2bf3c | ||
|
|
d7f1f06daf | ||
|
|
4c26e0c972 | ||
|
|
d7fafa361d | ||
|
|
0c69651ef3 | ||
|
|
b5fdbfbbf6 | ||
|
|
33a2688e77 | ||
|
|
6bd97cd8af | ||
|
|
64615cbef9 | ||
|
|
dfb3b59984 | ||
|
|
61641ecb00 | ||
|
|
5016881ce3 | ||
|
|
fe75b3e393 | ||
|
|
3db443f9b1 | ||
|
|
59c21ea19a | ||
|
|
95334441b1 | ||
|
|
276b224c13 | ||
|
|
f03da53d6f | ||
|
|
9201550027 | ||
|
|
c86cd9f901 | ||
|
|
ef8caea9d2 | ||
|
|
6daaff8c30 | ||
|
|
4b00c16eb7 | ||
|
|
4dae84034e | ||
|
|
e6eac03ec6 | ||
|
|
e9a479e2df | ||
|
|
7db1cc270e | ||
|
|
3bf5b86535 | ||
|
|
3d30632b41 | ||
|
|
287a6561ff | ||
|
|
169262cc62 | ||
|
|
fda5a577d6 | ||
|
|
f89be48e98 | ||
|
|
2c7afd0d55 | ||
|
|
2ad0553f6c | ||
|
|
340cb2c835 | ||
|
|
caefa501f2 | ||
|
|
5c96d75d39 | ||
|
|
86c2f96942 | ||
|
|
73899e3174 | ||
|
|
49bb2c6d8b | ||
|
|
9223a4f856 | ||
|
|
f3f60af231 | ||
|
|
3cdfcae01c | ||
|
|
0c6efada43 | ||
|
|
d79f73eab6 | ||
|
|
3ae720ef30 | ||
|
|
221e88de0f | ||
|
|
23d926f195 | ||
|
|
97b11ec244 | ||
|
|
899047dbd1 | ||
|
|
cb4b91878f | ||
|
|
6af661459c | ||
|
|
0e0ba51750 | ||
|
|
a5c32ac064 | ||
|
|
abbe8c84a1 | ||
|
|
6c0f88d8b5 | ||
|
|
68ada561ac | ||
|
|
18b52ec742 | ||
|
|
ca8d7d89c1 | ||
|
|
e6ab7cb5ff | ||
|
|
9679169e6f | ||
|
|
ed0f856438 | ||
|
|
9aa5c93d9d | ||
|
|
b45592c009 | ||
|
|
6e0fa95e6f | ||
|
|
94f310d17f | ||
|
|
2bc29d64a4 | ||
|
|
c220ca69c2 | ||
|
|
4280aad0a7 | ||
|
|
c98d15059b | ||
|
|
a862a83272 | ||
|
|
c6d59701db | ||
|
|
39a85dc4ed | ||
|
|
507c02a8fd | ||
|
|
380597f0c3 | ||
|
|
e469c449b4 | ||
|
|
f8bdb8a4b4 | ||
|
|
d76216a2ec | ||
|
|
82cfb3050d | ||
|
|
57f7d0c67d | ||
|
|
c11a242f34 | ||
|
|
576fad5fb1 | ||
|
|
8171d754e0 | ||
|
|
6be0f02c75 | ||
|
|
95e3138ab2 | ||
|
|
3a30a1a317 | ||
|
|
46733d1728 | ||
|
|
b6734d99e1 | ||
|
|
9cb01149f8 | ||
|
|
db88127da9 | ||
|
|
0e492ef402 | ||
|
|
4d095f9676 | ||
|
|
7c7ccadcfd | ||
|
|
02b9dc6947 | ||
|
|
4222d7e2b5 | ||
|
|
b5daf04c96 | ||
|
|
d26089a249 | ||
|
|
627683d437 | ||
|
|
87e222eecc | ||
|
|
667272b103 | ||
|
|
3f08d95aaa | ||
|
|
27ebda3439 | ||
|
|
52eca178d3 | ||
|
|
f5a064aaad | ||
|
|
41e585324d | ||
|
|
c36082476f | ||
|
|
4bc1378285 | ||
|
|
23d01c4f83 | ||
|
|
49ebb70cb3 | ||
|
|
27bb1b322e | ||
|
|
0534053ff9 | ||
|
|
fe2a424718 | ||
|
|
d79e30cd7a | ||
|
|
aa433baa06 | ||
|
|
30bb9a1dc0 | ||
|
|
c168925f03 | ||
|
|
42c0bea8fc | ||
|
|
291c00749a | ||
|
|
6763ea0f7b | ||
|
|
f917defea8 | ||
|
|
91718592d4 | ||
|
|
6d9a8095eb | ||
|
|
5b8fdbc6fd | ||
|
|
e15e44e4a2 | ||
|
|
980d317152 | ||
|
|
2c23c2f574 | ||
|
|
80ad75503b | ||
|
|
0143b0a1b2 | ||
|
|
dd115f9d37 | ||
|
|
f198ce1ef8 | ||
|
|
18211a215d | ||
|
|
d6a386ad03 | ||
|
|
8f19584ad7 | ||
|
|
d9eaf79ab8 | ||
|
|
05b1a7ec3b | ||
|
|
11cdcb29fa | ||
|
|
46a9e5b097 | ||
|
|
8c65ac80b5 | ||
|
|
5b3a9ffba8 | ||
|
|
8688c8c2b3 | ||
|
|
13d8b9cff5 | ||
|
|
aa448b1ba3 | ||
|
|
7ef3d92f2c | ||
|
|
9f09598ccd | ||
|
|
36c8938ff7 | ||
|
|
3271db4ca6 | ||
|
|
400d8b34fd | ||
|
|
d4374f7074 | ||
|
|
438ee539d6 | ||
|
|
5ccde0206c | ||
|
|
74ed269544 | ||
|
|
1ad2c5402b | ||
|
|
805abeb261 | ||
|
|
551a376b00 | ||
|
|
47be92d5f6 | ||
|
|
5feac0c14e | ||
|
|
0f02e21eaa | ||
|
|
cdb84716b7 | ||
|
|
731238de08 | ||
|
|
cb1ad94086 | ||
|
|
24b0be7566 | ||
|
|
3241ef87a3 | ||
|
|
bd541e1199 | ||
|
|
f998f27c6f | ||
|
|
cf788316c3 | ||
|
|
8ed7e91221 | ||
|
|
e537b738e4 | ||
|
|
bf3a111e55 | ||
|
|
63a526709d | ||
|
|
e0088e8f8f | ||
|
|
7d33c49db8 | ||
|
|
acee0a36da | ||
|
|
a528ba6070 | ||
|
|
ba224ac335 | ||
|
|
18aa14118c | ||
|
|
4e139c837a | ||
|
|
6ad74bb206 | ||
|
|
16c3de44e4 | ||
|
|
608a4f7a2f | ||
|
|
cd972d0850 | ||
|
|
2aeeb1a14e | ||
|
|
41aa6ff8e4 | ||
|
|
67f658ced2 | ||
|
|
94fa86629d | ||
|
|
2cb0640a7b | ||
|
|
772d5bd3ae | ||
|
|
d6681ead60 | ||
|
|
d94fd4714f | ||
|
|
0365886270 | ||
|
|
aa6c658a9a | ||
|
|
6e3f871015 | ||
|
|
6c9983164e | ||
|
|
e00b935c1f | ||
|
|
f5f9c0314a | ||
|
|
910384d17f | ||
|
|
2fe56e7462 | ||
|
|
b8989e204f | ||
|
|
84b7c2c062 | ||
|
|
14373309aa | ||
|
|
9cd3e22721 | ||
|
|
2f812372d7 | ||
|
|
9f813ca1c0 | ||
|
|
70784598e1 | ||
|
|
ad20380283 | ||
|
|
f4e3922f47 | ||
|
|
bfac71a7d0 | ||
|
|
435e789fa4 | ||
|
|
b9f5e0a354 | ||
|
|
372e4acc2c | ||
|
|
8f37c3175f | ||
|
|
438e8463f5 | ||
|
|
ae9aa1ba0f | ||
|
|
688d26c457 | ||
|
|
8b99b46787 | ||
|
|
80cce47f58 | ||
|
|
c92184d89f | ||
|
|
a8fe575999 | ||
|
|
38099c8cc7 | ||
|
|
32e51257f6 | ||
|
|
e44e7ca8d5 | ||
|
|
d533496b2a | ||
|
|
eb9db28352 | ||
|
|
836ee5f3ed | ||
|
|
0fa31f84a3 | ||
|
|
9a3c71df4a | ||
|
|
e4e39e0496 | ||
|
|
cd901fbc68 | ||
|
|
8fde378534 | ||
|
|
d8f81aca65 | ||
|
|
b9895d541b | ||
|
|
4fbc3c8533 | ||
|
|
62e2bf7de6 | ||
|
|
c6f9ea2b20 | ||
|
|
166a3cfe09 | ||
|
|
cbf84876de | ||
|
|
8bd43f113b | ||
|
|
d710986923 | ||
|
|
dd63b78937 | ||
|
|
8d1a07585b | ||
|
|
f4638b48d5 | ||
|
|
a1d56fc7e4 | ||
|
|
9879c99873 | ||
|
|
1016efc646 | ||
|
|
156aa6822b | ||
|
|
34eb94320e | ||
|
|
ba1d14f40a | ||
|
|
dc4fbb1a14 | ||
|
|
7d770d2a6d | ||
|
|
7b861f563c | ||
|
|
65614ed3ba | ||
|
|
7acdf9cb3d | ||
|
|
2e4954a4e8 | ||
|
|
36052cb4f2 | ||
|
|
8d81f63e08 | ||
|
|
bf5be11549 | ||
|
|
36907ccac1 | ||
|
|
0f3e9e8754 | ||
|
|
883cdbefbe | ||
|
|
968d50e071 | ||
|
|
a9bb1c1f68 | ||
|
|
207d6cba43 | ||
|
|
40ba179703 | ||
|
|
5b8fde70ca | ||
|
|
e3232f0198 | ||
|
|
c5536aa25d | ||
|
|
37bb9ff767 | ||
|
|
f2e8207245 | ||
|
|
86bf40c8fb | ||
|
|
f90ab22119 | ||
|
|
8e7d224b7b | ||
|
|
29925a90ca | ||
|
|
fdf48d7bfd | ||
|
|
d9efe8f137 | ||
|
|
6afb2f0914 | ||
|
|
82ab9e3a6a | ||
|
|
f5ca33859c | ||
|
|
a0b4fb33c1 | ||
|
|
40705181a0 | ||
|
|
308178f299 | ||
|
|
410f2c93b9 | ||
|
|
767abc2b90 | ||
|
|
ab19b18169 | ||
|
|
406fe995c9 | ||
|
|
151963ea79 | ||
|
|
5583426115 | ||
|
|
9d1707eba1 | ||
|
|
f605cf0ef9 | ||
|
|
119e5b8c58 | ||
|
|
de764f3353 | ||
|
|
e4280e5516 | ||
|
|
886ee32e2f | ||
|
|
7544780ff7 | ||
|
|
26beff1080 | ||
|
|
88b8130d37 | ||
|
|
e7b8c68dba | ||
|
|
b753eda9b9 | ||
|
|
e6f2ce2598 | ||
|
|
a191bbbf6e | ||
|
|
65f48dd789 | ||
|
|
c14d4e59d4 | ||
|
|
6b2a14e589 | ||
|
|
d6cefaa6d2 | ||
|
|
278af3ce34 | ||
|
|
3b1cbd4fd7 | ||
|
|
500652da36 | ||
|
|
eee39f56a8 | ||
|
|
719b797037 | ||
|
|
7cb8208065 | ||
|
|
a03ababc80 | ||
|
|
a3565a0ae9 | ||
|
|
07a64375ce | ||
|
|
ec1e63e2f2 | ||
|
|
8487966888 | ||
|
|
4bbd2c7446 | ||
|
|
d9bab899fe | ||
|
|
e3b3e8861d | ||
|
|
69242f9f2d | ||
|
|
a21c7bce40 | ||
|
|
7379e3a833 | ||
|
|
b565b9c4e5 | ||
|
|
99ac17a5e4 | ||
|
|
4ccee8b107 | ||
|
|
eccd791a39 | ||
|
|
3d6878fe9f | ||
|
|
1e1803d148 | ||
|
|
343995fd8f | ||
|
|
71cb20eb4f | ||
|
|
741ed4191b | ||
|
|
2a6391a2ee | ||
|
|
405e846a0e | ||
|
|
155cf22841 | ||
|
|
e688422ec3 | ||
|
|
6047c787b3 | ||
|
|
c03166b622 | ||
|
|
617db611c3 | ||
|
|
f931951ad5 | ||
|
|
df7fea9ee3 | ||
|
|
24e8e756de | ||
|
|
ff5479f1a2 | ||
|
|
4ad3e1d9cf | ||
|
|
e4af83380b | ||
|
|
d30fe53ef3 | ||
|
|
9e0abc60be | ||
|
|
dc33cdce5a | ||
|
|
f5b8690ce0 | ||
|
|
fbb67f54ab | ||
|
|
5c4e80c8f8 | ||
|
|
0774c35589 | ||
|
|
553963a10a | ||
|
|
0bf171c8a6 | ||
|
|
f5549f7430 | ||
|
|
34816451c0 | ||
|
|
e1772c25c4 | ||
|
|
65653c0ee2 | ||
|
|
3addfb9a87 | ||
|
|
67be01f504 | ||
|
|
99aa7c9433 | ||
|
|
ccfdec8b45 | ||
|
|
dbd259adb0 | ||
|
|
ec151d7b53 | ||
|
|
5fbb1927a4 | ||
|
|
3bc13642af | ||
|
|
6b38a4f602 | ||
|
|
51f62438cf | ||
|
|
9d8b54072a | ||
|
|
0982c013d1 | ||
|
|
03a4719aed | ||
|
|
5c3fd75ed4 | ||
|
|
6dfc728b5b | ||
|
|
071ef8432e | ||
|
|
8ad5ffe20f | ||
|
|
49ba0e3f41 | ||
|
|
5a42ff4b78 | ||
|
|
20472dce88 | ||
|
|
057a7c9c6d | ||
|
|
63eff5c62a | ||
|
|
b824cd8ce7 | ||
|
|
eee82dea07 | ||
|
|
98656eb445 | ||
|
|
c8ff37e304 | ||
|
|
3fe726ac23 | ||
|
|
d5a253a781 | ||
|
|
6a5a95e543 | ||
|
|
a94a44826b | ||
|
|
31c8fea2dc | ||
|
|
13dcf0edb9 | ||
|
|
7bc64bc53b | ||
|
|
4991efe50c | ||
|
|
a87304b22b | ||
|
|
3701ade05f | ||
|
|
45ed5648c4 | ||
|
|
682d35b2b2 | ||
|
|
30f5fc9ab1 | ||
|
|
bc69d2221a | ||
|
|
592870f397 | ||
|
|
157c2999dc | ||
|
|
b6d984475c | ||
|
|
6a98ce9d28 | ||
|
|
c802c508fb | ||
|
|
545fb139ee | ||
|
|
c38507d50c |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -15,4 +15,6 @@ assembly/runtime/*
|
|||||||
/assembly/deploy
|
/assembly/deploy
|
||||||
/runtime
|
/runtime
|
||||||
**/.flattened-pom.xml
|
**/.flattened-pom.xml
|
||||||
|
chm_db/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
|
/dict
|
||||||
52
CHANGELOG.md
52
CHANGELOG.md
@@ -4,24 +4,48 @@
|
|||||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||||
compatibility issues with previous versions.
|
compatibility issues with previous versions.
|
||||||
|
|
||||||
## SuperSonic [0.7.4] - 2023-09-10
|
## SuperSonic [0.8.2] - 2023-12-18
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
- add llm parser config
|
- rewrite Python service with Java project, default to Java implementation.
|
||||||
- add datasource agg_time option
|
- support setting the SQL generation method for large models in the interface.
|
||||||
- add function name adaptor in clickhouse
|
- optimization of metric market experience.
|
||||||
- add dimension and metric show in dsl
|
- optimization of semantic modeling canvas experience.
|
||||||
|
- code structure adjustment and abstraction optimization for chat.
|
||||||
|
|
||||||
|
## SuperSonic [0.7.5] - 2023-10-13
|
||||||
|
|
||||||
### Updated
|
### Added
|
||||||
- update user guide doc
|
- add SQL generation improvement optimization, support LLM SQL, Logic SQL, and Physical SQL display.
|
||||||
- update query building of plugin in default model
|
- add showcase functionality to support recommending similar questions.
|
||||||
- update some core API constructs to keep naming consistency
|
- add frontend modification of filtering conditions and re-querying feature.
|
||||||
- update ConfigureDemo config
|
- support nested query functionality in semantic.
|
||||||
- update the association mechanism so that invisible dimensions and metrics will no longer be associated
|
- support switching queries between multiple parsers in the frontend.
|
||||||
|
|
||||||
### Fixed
|
### Updated
|
||||||
- fix hasAggregateFunction logic in SqlParserSelectHelper
|
- optimizing the build and deployment of the project.
|
||||||
|
- overall optimization of the SQL Corrector functionality.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- fix execute error on mysql <=5.7
|
||||||
|
|
||||||
|
## SuperSonic [0.7.4] - 2023-09-10
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- add llm parser config
|
||||||
|
- add datasource agg_time option
|
||||||
|
- add function name adaptor in clickhouse
|
||||||
|
- add dimension and metric show in dsl
|
||||||
|
|
||||||
|
### Updated
|
||||||
|
- update user guide doc
|
||||||
|
- update query building of plugin in default model
|
||||||
|
- update some core API constructs to keep naming consistency
|
||||||
|
- update ConfigureDemo config
|
||||||
|
- update the association mechanism so that invisible dimensions and metrics will no longer be associated
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- fix hasAggregateFunction logic in SqlParserSelectHelper
|
||||||
|
|
||||||
## SuperSonic [0.7.3] - 2023-08-29
|
## SuperSonic [0.7.3] - 2023-08-29
|
||||||
|
|
||||||
|
|||||||
26
README.md
26
README.md
@@ -2,25 +2,25 @@
|
|||||||
|
|
||||||
# SuperSonic (超音数)
|
# SuperSonic (超音数)
|
||||||
|
|
||||||
**SuperSonic is an out-of-the-box yet highly extensible framework for building a data chatbot**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of metrics/dimensions/entities, along with their meaning, context and relationships) on top of physical data models, and no data modification or copying is required. Meanwhile, SuperSonic is designed to be pluggable, allowing new functionalities to be added through plugins and core components to be integrated with other systems.
|
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||||
|
|
||||||
## Motivation
|
## Motivation
|
||||||
|
|
||||||
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language queries into SQL queries. While some works show promising results, they are still not applicable to real-world scenarios.
|
The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLM to convert natural language into SQL (so called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications.
|
||||||
|
|
||||||
From our perspective, the key to filling the real-world gap lies in three aspects:
|
From our perspective, the key to filling the real-world gap lies in three aspects:
|
||||||
1. Complement the LLM-based semantic parser with rule-based semantic parsers to improve **efficiency**(in terms of latency and cost).
|
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
|
||||||
2. Augment semantic parsing with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to improve **accuracy** and **stability**.
|
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to **mitigate hallucination**.
|
||||||
3. Introduce a semantic layer encapsulating underlying data context(joins, formulas, etc) to reduce **complexity**.
|
3. Utilize rule-based schema parsers when necessary to **improve efficiency**(in terms of latency and cost).
|
||||||
|
|
||||||
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of data chatbot, we decide to open source SuperSonic as an extensible framework.
|
With these ideas in mind, we develop SuperSonic as a practical reference implementation and use it to power our real-world products. Additionally, to facilitate further development of ChatBI, we decide to open source SuperSonic as an extensible framework.
|
||||||
|
|
||||||
## Out-of-the-box Features
|
## Out-of-the-box Features
|
||||||
|
|
||||||
- Built-in CUI(Chat User Interface) for *business users* to enter data queries
|
- Built-in ChatBI interface for *business users* to enter natural language queries
|
||||||
- Built-in GUI(Graphical User Interface) for *analytics engineers* to build semantic models
|
- Built-in HeadlessBI interface for *analytics engineers* to build semantic models
|
||||||
- Built-in GUI for *system administrators* to manage chat agents and third-party plugins
|
- Built-in GUI for *system administrators* to manage chat agents and third-party plugins
|
||||||
- Support input auto-completion as well as query recommendation
|
- Support input auto-completion as well as query recommendation
|
||||||
- Support multi-turn conversation and history context management
|
- Support multi-turn conversation and history context management
|
||||||
@@ -40,7 +40,7 @@ The high-level architecture and main process flow is as follows:
|
|||||||
|
|
||||||
- **Semantic Corrector:** checks validity of extracted semantic information and performs correction and optimization if needed.
|
- **Semantic Corrector:** checks validity of extracted semantic information and performs correction and optimization if needed.
|
||||||
|
|
||||||
- **Semantic Layer:** performs execution according to extracted semantic information. It generates SQL queries and executes them against physical data models.
|
- **Semantic Interpreter:** performs execution according to extracted semantic information. It generates SQL statements and executes them against physical data models.
|
||||||
|
|
||||||
- **Chat Plugin:** extends functionality with third-party tools. The LLM is going to select the most suitable one, given all configured plugins with function description and sample questions.
|
- **Chat Plugin:** extends functionality with third-party tools. The LLM is going to select the most suitable one, given all configured plugins with function description and sample questions.
|
||||||
|
|
||||||
@@ -49,15 +49,15 @@ The high-level architecture and main process flow is as follows:
|
|||||||
SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:
|
SuperSonic comes with sample semantic models as well as chat conversations that can be used as a starting point. Please follow the steps:
|
||||||
|
|
||||||
- Download the latest prebuilt binary from the [release page](https://github.com/tencentmusic/supersonic/releases)
|
- Download the latest prebuilt binary from the [release page](https://github.com/tencentmusic/supersonic/releases)
|
||||||
- Run script "bin/start-standalone.sh" to start services (one java process and one python process)
|
- Run script "assembly/bin/supersonic-daemon.sh start" to start a standalone Java service
|
||||||
- Visit http://localhost:9080 in the browser to start exploration
|
- Visit http://localhost:9080 in the browser to start exploration
|
||||||
|
|
||||||
## Build and Delopment
|
## Build and Development
|
||||||
|
|
||||||
Please refer to project [wiki](https://github.com/tencentmusic/supersonic/wiki).
|
Please refer to project [wiki](https://github.com/tencentmusic/supersonic/wiki).
|
||||||
|
|
||||||
## WeChat Contact
|
## WeChat Contact
|
||||||
|
|
||||||
Please join the chat group to suggest feedbacks or ideas:
|
Please follow SuperSonic wechat official account:
|
||||||
|
|
||||||
<img src="./docs/images/wechat_contact.jpeg" height="40%" width="40%" align="center"/>
|
<img src="./docs/images/supersonic_wechat_oa.png" height="50%" width="50%" align="center"/>
|
||||||
34
README_CN.md
34
README_CN.md
@@ -1,36 +1,36 @@
|
|||||||
# 超音数(SuperSonic)
|
# SuperSonic (超音数)
|
||||||
|
|
||||||
**超音数是一个开箱即用且易于扩展的数据问答对话框架**。通过超音数的问答对话界面,用户能够使用自然语言查询数据,系统会选择合适的可视化图表呈现结果。超音数不需要修改或复制数据,只需要在物理数据模型之上构建逻辑语义模型(指标/维度/实体的定义,以及他们的业务含义、相互间关系等),即可开启数据问答体验。与此同时,超音数被设计为可插拔式的框架,允许以插件形式来扩展新功能,或者将核心组件与其他系统集成。
|
**SuperSonic融合ChatBI和HeadlessBI打造新一代的数据分析平台**。通过SuperSonic的问答对话界面,用户能够使用自然语言查询数据,系统会选择合适的可视化图表呈现结果。SuperSonic不需要修改或复制数据,只需要在物理数据模型之上构建逻辑语义模型(指标/维度/实体的定义,以及他们的业务含义、相互间关系等),即可开启数据问答体验。与此同时,SuperSonic被设计为可插拔的框架,采用Java SPI机制来扩展定制功能。
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||||
|
|
||||||
## 项目动机
|
## 项目动机
|
||||||
|
|
||||||
大型语言模型(LLMs)如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们还并不适用于实际场景。
|
大型语言模型(LLMs)如ChatGPT的出现正在重塑信息检索的方式。在数据分析领域,学术界和工业界主要关注利用深度学习模型将自然语言查询转换为SQL查询。虽然一些工作显示出有前景的结果,但它们的可靠性还达不到生产可用的要求。
|
||||||
|
|
||||||
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
||||||
1. 在基于大模型语义解析器基础上,增加基于规则的解析器,提升语义解析的**效率**。
|
1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。
|
||||||
2. 加入模式映射器和语义修正器,来增强语义解析能力,提升语义解析的**准确性**和**稳定性**。
|
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
|
||||||
3. 引入语义模型层,封装底层数据的上下文(关联、公式等),降低语义解析的**复杂性**。
|
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
|
||||||
|
|
||||||
为了验证上述想法,我们开发了超音数项目,并将其应用在实际的内部产品中。与此同时,我们将超音数作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。
|
为了验证上述想法,我们开发了SuperSonic项目,并将其应用在实际的内部产品中。与此同时,我们将SuperSonic作为一个可扩展的框架开源,希望能够促进数据问答对话领域的进一步发展。
|
||||||
|
|
||||||
## 开箱即用的特性
|
## 开箱即用的特性
|
||||||
|
|
||||||
- 内置对话界面以便*业务用户*输入数据查询。
|
- 内置ChatBI界面以便*业务用户*输入数据查询。
|
||||||
- 内置图形界面以便*分析工程师*构建语义模型。
|
- 内置HeadlessBI界面以便*分析工程师*构建语义模型。
|
||||||
- 内置图形界面以便*系统管理员*管理第三方插件和对话助理。
|
- 内置图形用户界面以便*系统管理员*管理第三方插件和对话助理。
|
||||||
- 支持文本输入的联想和查询问题的推荐。
|
- 支持文本输入的联想和查询问题的推荐。
|
||||||
- 支持多轮对话,根据语境自动切换上下文。
|
- 支持多轮对话,根据语境自动切换上下文。
|
||||||
- 支持四级权限控制:主题域级、模型级、列级、行级。
|
- 支持四级权限控制:主题域级、模型级、列级、行级。
|
||||||
|
|
||||||
## 易于扩展的组件
|
## 易于扩展的组件
|
||||||
|
|
||||||
超音数的整体架构和主流程如下图所示:
|
SuperSonic的整体架构和主流程如下图所示:
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_components.png" height="65%" width="65%" align="center"/>
|
<img src="./docs/images/supersonic_components.png" height="65%" width="65%" align="center"/>
|
||||||
|
|
||||||
- **知识库(Knowledge Base):** 定期从语义模型中提取相关的模式信息,构建词典和索引,以便后续的模式映射。
|
- **模型知识库(Knowledge Base):** 定期从语义模型中提取相关的模式信息,构建词典和索引,以便后续的模式映射。
|
||||||
|
|
||||||
- **模式映射器(Schema Mapper):** 将自然语言文本在知识库中进行匹配,为后续的语义解析提供相关信息。
|
- **模式映射器(Schema Mapper):** 将自然语言文本在知识库中进行匹配,为后续的语义解析提供相关信息。
|
||||||
|
|
||||||
@@ -38,16 +38,16 @@
|
|||||||
|
|
||||||
- **语义修正器(Semantic Corrector):** 检查语义信息的合法性,对不合法的信息做修正和优化处理。
|
- **语义修正器(Semantic Corrector):** 检查语义信息的合法性,对不合法的信息做修正和优化处理。
|
||||||
|
|
||||||
- **语义模型层(Semantic Layer):** 根据语义信息生成物理SQL执行查询。
|
- **语义解释器(Semantic Interpreter):** 根据语义信息生成物理SQL执行查询。
|
||||||
|
|
||||||
- **问答插件(Chat Plugin):** 通过第三方工具扩展功能。给定所有配置的插件及其功能描述和示例问题,大语言模型将选择最合适的插件。
|
- **问答插件(Chat Plugin):** 通过第三方工具扩展功能。给定所有配置的插件及其功能描述和示例问题,大语言模型将选择最合适的插件。
|
||||||
|
|
||||||
## 快速体验
|
## 快速体验
|
||||||
|
|
||||||
超音数自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
SuperSonic自带样例的语义模型和问答对话,只需以下三步即可快速体验:
|
||||||
|
|
||||||
- 从[release page](https://github.com/tencentmusic/supersonic/releases)下载预先构建好的发行包
|
- 从[release page](https://github.com/tencentmusic/supersonic/releases)下载预先构建好的发行包
|
||||||
- 运行 "bin/start-standalone.sh"启动服务(一个Java进程和一个Python进程)
|
- 运行 "assembly/bin/supersonic-daemon.sh start"启动standalone模式的Java服务
|
||||||
- 在浏览器访问http://localhost:9080 开启探索
|
- 在浏览器访问http://localhost:9080 开启探索
|
||||||
|
|
||||||
## 如何构建和部署
|
## 如何构建和部署
|
||||||
@@ -56,6 +56,6 @@
|
|||||||
|
|
||||||
## 微信联系方式
|
## 微信联系方式
|
||||||
|
|
||||||
欢迎加入微信群反馈建议:
|
欢迎关注微信公众号:
|
||||||
|
|
||||||
<img src="./docs/images/wechat_contact.jpeg" height="40%" width="40%" align="center"/>
|
<img src="./docs/images/supersonic_wechat_oa.png" height="50%" width="50%" align="center"/>
|
||||||
@@ -1,30 +1,72 @@
|
|||||||
@echo off
|
@echo off
|
||||||
setlocal
|
setlocal
|
||||||
|
chcp 65001
|
||||||
set "sbinDir=%~dp0"
|
set "sbinDir=%~dp0"
|
||||||
set "baseDir=%~dp0.."
|
set "baseDir=%~dp0.."
|
||||||
set "buildDir=%baseDir%\build"
|
set "buildDir=%baseDir%\build"
|
||||||
|
set "runtimeDir=%baseDir%\..\runtime"
|
||||||
|
set "pip_path=pip3"
|
||||||
|
set "service=%~1"
|
||||||
|
|
||||||
|
|
||||||
rem 1. build semantic chat service
|
rem 1. build backend java modules
|
||||||
del /q "%buildDir%\*.tar.gz" 2>NUL
|
del /q "%buildDir%\*.tar.gz" 2>NUL
|
||||||
|
|
||||||
call mvn -f "%baseDir%\..\pom.xml" clean package -DskipTests
|
call mvn -f "%baseDir%\..\pom.xml" clean package -DskipTests
|
||||||
|
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
ECHO Failed to build backend Java modules.
|
||||||
|
EXIT /B 1
|
||||||
|
)
|
||||||
|
|
||||||
rem 2. move package to build
|
rem 2. move package to build
|
||||||
echo f|xcopy "%baseDir%\..\launchers\standalone\target\*.tar.gz" "%buildDir%\supersonic-standalone.tar.gz"
|
echo f|xcopy "%baseDir%\..\launchers\standalone\target\*.tar.gz" "%buildDir%\supersonic-standalone.tar.gz"
|
||||||
echo f|xcopy "%baseDir%\..\launchers\semantic\target\*.tar.gz" "%buildDir%\supersonic-semantic.tar.gz"
|
|
||||||
echo f|xcopy "%baseDir%\..\launchers\chat\target\*.tar.gz" "%buildDir%\supersonic-chat.tar.gz"
|
|
||||||
|
|
||||||
rem 3. build webapp
|
rem 3. build frontend webapp
|
||||||
cd "%baseDir%\..\webapp"
|
cd "%baseDir%\..\webapp"
|
||||||
call start-fe-prod.bat
|
call start-fe-prod.bat
|
||||||
copy /y "%baseDir%\..\webapp\supersonic-webapp.tar.gz" "%buildDir%\"
|
copy /y "%baseDir%\..\webapp\supersonic-webapp.tar.gz" "%buildDir%\"
|
||||||
|
|
||||||
|
IF ERRORLEVEL 1 (
|
||||||
|
ECHO Failed to build frontend webapp.
|
||||||
|
EXIT /B 1
|
||||||
|
)
|
||||||
|
|
||||||
|
rem 4. copy webapp to java classpath
|
||||||
cd "%buildDir%"
|
cd "%buildDir%"
|
||||||
tar -zxvf supersonic-webapp.tar.gz
|
tar -zxvf supersonic-webapp.tar.gz
|
||||||
move supersonic-webapp webapp
|
move supersonic-webapp webapp
|
||||||
move webapp ..\..\launchers\standalone\target\classes
|
move webapp ..\..\launchers\standalone\target\classes
|
||||||
|
|
||||||
|
rem 5. build backend python modules
|
||||||
|
if "%service%"=="pyllm" (
|
||||||
|
echo "start installing python modules with pip: ${pip_path}"
|
||||||
|
set requirementPath="%baseDir%/../chat/python/requirements.txt"
|
||||||
|
%pip_path% install -r %requirementPath%
|
||||||
|
echo "install python modules success"
|
||||||
|
)
|
||||||
|
|
||||||
|
call :BUILD_RUNTIME
|
||||||
|
|
||||||
|
:BUILD_RUNTIME
|
||||||
|
rem 6. reset runtime
|
||||||
|
IF EXIST "%runtimeDir%" (
|
||||||
|
echo begin to delete dir : %runtimeDir%
|
||||||
|
rd /s /q "%runtimeDir%"
|
||||||
|
) ELSE (
|
||||||
|
echo %runtimeDir% does not exist, create directly
|
||||||
|
)
|
||||||
|
mkdir "%runtimeDir%"
|
||||||
|
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
|
||||||
|
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
|
||||||
|
move "%%f" "%runtimeDir%\supersonic-standalone"
|
||||||
|
)
|
||||||
|
|
||||||
|
rem 7. copy webapp to runtime
|
||||||
|
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
|
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
|
rd /s /q "%buildDir%\supersonic-webapp"
|
||||||
|
|
||||||
endlocal
|
endlocal
|
||||||
49
assembly/bin/supersonic-build.sh
Normal file → Executable file
49
assembly/bin/supersonic-build.sh
Normal file → Executable file
@@ -1,29 +1,58 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -x
|
||||||
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||||
baseDir=$(cd "$sbinDir/.." && pwd -P)
|
chmod +x $sbinDir/supersonic-common.sh
|
||||||
runtimeDir=$baseDir/runtime
|
source $sbinDir/supersonic-common.sh
|
||||||
buildDir=$baseDir/build
|
|
||||||
|
|
||||||
cd $baseDir
|
cd $baseDir
|
||||||
|
|
||||||
#1. build semantic chat service
|
service=$1
|
||||||
|
#1. build backend java modules
|
||||||
rm -fr ${buildDir}/*.tar.gz
|
rm -fr ${buildDir}/*.tar.gz
|
||||||
rm -fr dist
|
rm -fr dist
|
||||||
|
set +x
|
||||||
mvn -f $baseDir/../ clean package -DskipTests
|
mvn -f $baseDir/../ clean package -DskipTests
|
||||||
|
# check build result
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed to build backend Java modules."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
#2. move package to build
|
#2. move package to build
|
||||||
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic.tar.gz
|
cp $baseDir/../launchers/semantic/target/*.tar.gz ${buildDir}/supersonic-semantic.tar.gz
|
||||||
|
cp $baseDir/../launchers/chat/target/*.tar.gz ${buildDir}/supersonic-chat.tar.gz
|
||||||
|
cp $baseDir/../launchers/standalone/target/*.tar.gz ${buildDir}/supersonic-standalone.tar.gz
|
||||||
|
|
||||||
#3. build webapp
|
#3. build frontend webapp
|
||||||
chmod +x $baseDir/../webapp/start-fe-prod.sh
|
chmod +x $baseDir/../webapp/start-fe-prod.sh
|
||||||
cd ../webapp
|
cd ../webapp
|
||||||
sh ./start-fe-prod.sh
|
sh ./start-fe-prod.sh
|
||||||
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/
|
cp -fr ./supersonic-webapp.tar.gz ${buildDir}/
|
||||||
|
|
||||||
|
# check build result
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed to build frontend webapp."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
#4. copy webapp to java classpath
|
||||||
cd $buildDir
|
cd $buildDir
|
||||||
tar xvf supersonic-webapp.tar.gz
|
tar xvf supersonic-webapp.tar.gz
|
||||||
mv supersonic-webapp webapp
|
mv supersonic-webapp webapp
|
||||||
mv webapp ../../launchers/standalone/target/classes
|
cp -fr webapp ../../launchers/semantic/target/classes
|
||||||
|
cp -fr webapp ../../launchers/chat/target/classes
|
||||||
|
cp -fr webapp ../../launchers/standalone/target/classes
|
||||||
|
rm -fr ${buildDir}/webapp
|
||||||
|
|
||||||
|
#5. build backend python modules
|
||||||
|
if [ "$service" == "pyllm" ]; then
|
||||||
|
echo "start installing python modules with pip: ${pip_path}"
|
||||||
|
requirementPath=$baseDir/../chat/python/requirements.txt
|
||||||
|
${pip_path} install -r ${requirementPath}
|
||||||
|
echo "install python modules success"
|
||||||
|
fi
|
||||||
|
|
||||||
|
#6. reset runtime
|
||||||
|
rm -fr $runtimeDir/supersonic*
|
||||||
|
moveAllToRuntime
|
||||||
|
setEnvToWeb chat
|
||||||
|
setEnvToWeb semantic
|
||||||
|
|||||||
110
assembly/bin/supersonic-common.sh
Executable file
110
assembly/bin/supersonic-common.sh
Executable file
@@ -0,0 +1,110 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# environment parameters
|
||||||
|
python_path=${PYTHON_PATH:-"python3"}
|
||||||
|
pip_path=${PIP_PATH:-"pip3"}
|
||||||
|
|
||||||
|
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||||
|
baseDir=$(cd "$sbinDir/.." && pwd -P)
|
||||||
|
runtimeDir=$baseDir/../runtime
|
||||||
|
buildDir=$baseDir/build
|
||||||
|
|
||||||
|
readonly CHAT_APP_NAME="supersonic_chat"
|
||||||
|
readonly SEMANTIC_APP_NAME="supersonic_semantic"
|
||||||
|
readonly PYLLM_APP_NAME="supersonic_pyllm"
|
||||||
|
readonly STANDALONE_APP_NAME="supersonic_standalone"
|
||||||
|
readonly CHAT_SERVICE="chat"
|
||||||
|
readonly SEMANTIC_SERVICE="semantic"
|
||||||
|
readonly PYLLM_SERVICE="pyllm"
|
||||||
|
readonly STANDALONE_SERVICE="standalone"
|
||||||
|
readonly PYLLM_HOST="127.0.0.1"
|
||||||
|
readonly PYLLM_PORT="9092"
|
||||||
|
|
||||||
|
function setEnvToWeb {
|
||||||
|
model_name=$1
|
||||||
|
json='{"env": "'$model_name'"}'
|
||||||
|
echo $json > ${runtimeDir}/supersonic-${model_name}/webapp/supersonic.config.json
|
||||||
|
echo $json > $baseDir/../launchers/${model_name}/target/classes/webapp/supersonic.config.json
|
||||||
|
}
|
||||||
|
|
||||||
|
function moveToRuntime {
|
||||||
|
model_name=$1
|
||||||
|
file="${buildDir}/supersonic-${model_name}.tar.gz"
|
||||||
|
if [ -f "$file" ]; then
|
||||||
|
tar -zxvf "$file" -C ${runtimeDir}
|
||||||
|
mv ${runtimeDir}/launchers-${model_name}-* ${runtimeDir}/supersonic-${model_name}
|
||||||
|
mkdir -p ${runtimeDir}/supersonic-${model_name}/webapp
|
||||||
|
cp -fr ${buildDir}/webapp/* ${runtimeDir}/supersonic-${model_name}/webapp
|
||||||
|
else
|
||||||
|
echo "File $file does not exist. Skipping the move to runtime."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
function moveAllToRuntime {
|
||||||
|
mkdir -p ${runtimeDir}
|
||||||
|
tar xvf ${buildDir}/supersonic-webapp.tar.gz -C ${buildDir}
|
||||||
|
mv ${buildDir}/supersonic-webapp ${buildDir}/webapp
|
||||||
|
|
||||||
|
moveToRuntime chat
|
||||||
|
moveToRuntime semantic
|
||||||
|
moveToRuntime standalone
|
||||||
|
rm -fr ${buildDir}/webapp
|
||||||
|
}
|
||||||
|
|
||||||
|
# run java service
|
||||||
|
function runJavaService {
|
||||||
|
javaRunDir=${runtimeDir}/supersonic-${model_name}
|
||||||
|
local_app_name=$1
|
||||||
|
libDir=$javaRunDir/lib
|
||||||
|
confDir=$javaRunDir/conf
|
||||||
|
|
||||||
|
CLASSPATH=""
|
||||||
|
CLASSPATH=$CLASSPATH:$confDir
|
||||||
|
|
||||||
|
for jarPath in $libDir/*.jar; do
|
||||||
|
CLASSPATH=$CLASSPATH:$jarPath
|
||||||
|
done
|
||||||
|
|
||||||
|
export CLASSPATH
|
||||||
|
export LANG="zh_CN.UTF-8"
|
||||||
|
|
||||||
|
cd $javaRunDir
|
||||||
|
if [[ "$JAVA_HOME" == "" ]]; then
|
||||||
|
JAVA_HOME=$(ls /usr/jdk64/jdk* -d 2>/dev/null | xargs | awk '{print "'$local_app_name'"}')
|
||||||
|
fi
|
||||||
|
export PATH=$JAVA_HOME/bin:$PATH
|
||||||
|
command="-Dfile.encoding="UTF-8" -Duser.language="Zh" -Duser.region="CN" -Duser.timezone="GMT+08" -Dapp_name=${local_app_name} -Xms1024m -Xmx2048m "$main_class
|
||||||
|
|
||||||
|
mkdir -p $javaRunDir/logs
|
||||||
|
if [[ "$is_test" == "true" ]]; then
|
||||||
|
java -Dspring.profiles.active="dev" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
else
|
||||||
|
java $command $javaRunDir >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# run python service
|
||||||
|
function runPythonService {
|
||||||
|
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
|
||||||
|
cd $pythonRunDir
|
||||||
|
nohup ${python_path} supersonic_pyllm.py > $pythonRunDir/pyllm.log 2>&1 &
|
||||||
|
# add health check
|
||||||
|
for i in {1..10}
|
||||||
|
do
|
||||||
|
echo "pyllm health check attempt $i..."
|
||||||
|
response=$(curl -s http://${PYLLM_HOST}:${PYLLM_PORT}/health)
|
||||||
|
echo "pyllm health check response: $response"
|
||||||
|
status_ok="Healthy"
|
||||||
|
if [[ $response == *$status_ok* ]] ; then
|
||||||
|
echo "pyllm Health check passed."
|
||||||
|
break
|
||||||
|
else
|
||||||
|
if [ "$i" -eq 10 ]; then
|
||||||
|
echo "pyllm Health check failed after 10 attempts."
|
||||||
|
echo "May still downloading model files. Please check pyllm.log in runtime directory."
|
||||||
|
fi
|
||||||
|
echo "Retrying after 5 seconds..."
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
}
|
||||||
@@ -1,120 +1,118 @@
|
|||||||
@echo off
|
@echo off
|
||||||
setlocal
|
setlocal
|
||||||
|
chcp 65001
|
||||||
set "sbinDir=%~dp0"
|
set "sbinDir=%~dp0"
|
||||||
set "baseDir=%~dp0.."
|
set "baseDir=%~dp0.."
|
||||||
set "runtimeDir=%baseDir%\..\runtime"
|
set "runtimeDir=%baseDir%\..\runtime"
|
||||||
set "buildDir=%baseDir%\build"
|
set "buildDir=%baseDir%\build"
|
||||||
|
set "main_class=com.tencent.supersonic.StandaloneLauncher"
|
||||||
|
set "python_path=python"
|
||||||
|
set "pip_path=pip3"
|
||||||
|
set "standalone_service=standalone"
|
||||||
|
set "pyllm_service=pyllm"
|
||||||
|
|
||||||
|
set "javaRunDir=%runtimeDir%\supersonic-standalone"
|
||||||
|
set "pythonRunDir=%runtimeDir%\supersonic-standalone\pyllm"
|
||||||
|
|
||||||
set "command=%~1"
|
set "command=%~1"
|
||||||
set "module=%~2"
|
set "service=%~2"
|
||||||
|
|
||||||
set "APP_NAME=standalone-service"
|
if "%service%"=="" (
|
||||||
set "MAIN_CLASS=com.tencent.supersonic.StandaloneLauncher"
|
set "service=%standalone_service%"
|
||||||
|
|
||||||
set "python_path=python"
|
|
||||||
set "pip_path=pip3.9"
|
|
||||||
set "llm_host=127.0.0.1"
|
|
||||||
set "llm_port=9092"
|
|
||||||
set "start_name=api_service"
|
|
||||||
|
|
||||||
set "llm_path=%baseDir%\..\chat\core\src\main\python"
|
|
||||||
|
|
||||||
if "%module%"=="" (
|
|
||||||
set "module=standalone"
|
|
||||||
) else if "%module%"=="semantic" (
|
|
||||||
set "APP_NAME=semantic-service"
|
|
||||||
set "MAIN_CLASS=com.tencent.supersonic.SemanticLauncher"
|
|
||||||
) else if "%module%"=="chat" (
|
|
||||||
set "APP_NAME=chat-service"
|
|
||||||
set "MAIN_CLASS=com.tencent.supersonic.ChatLauncher"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if "%command%"=="" (
|
IF "%service%"=="pyllm" (
|
||||||
set "command=restart"
|
SET "llmProxy=PythonLLMProxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
set "libDir=%runtimeDir%\supersonic-%module%\lib"
|
call :BUILD_RUNTIME
|
||||||
set "confDir=%runtimeDir%\supersonic-%module%\conf"
|
|
||||||
set "webDir=%runtimeDir%\supersonic-%module%\webapp"
|
|
||||||
set "CLASSPATH=%confDir%;%webDir%;%libDir%\*"
|
|
||||||
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
|
||||||
|
|
||||||
|
|
||||||
if "%command%"=="stop" (
|
|
||||||
call:STOP
|
|
||||||
goto :EOF
|
|
||||||
)
|
|
||||||
|
|
||||||
if "%command%"=="restart" (
|
if "%command%"=="restart" (
|
||||||
call:STOP
|
call :STOP
|
||||||
)
|
call :START
|
||||||
|
|
||||||
::1. clear file
|
|
||||||
rd /s /q "%runtimeDir%"
|
|
||||||
mkdir "%runtimeDir%"
|
|
||||||
|
|
||||||
if "%module%"=="llmparser" (
|
|
||||||
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
|
|
||||||
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
|
|
||||||
move "%%f" "%runtimeDir%\supersonic-standalone"
|
|
||||||
)
|
|
||||||
cd "%runtimeDir%"
|
|
||||||
"%pip_path%" install -r "%llm_path%\requirements.txt"
|
|
||||||
"%python_path%" -c "import langchain,fastapi,chromadb,tiktoken,uvicorn" >nul 2>&1
|
|
||||||
cd "%runtimeDir%\supersonic-standalone\llm\llm"
|
|
||||||
start "" /B uvicorn %start_name%:app --port %llm_port% --host %llm_host% > "%runtimeDir%\supersonic-standalone\llm\llm.log" 2>&1
|
|
||||||
echo "llm service started, see logs/error with logs/error command"
|
|
||||||
goto :EOF
|
goto :EOF
|
||||||
)
|
) else if "%command%"=="start" (
|
||||||
|
call :START
|
||||||
tar -zxvf "%buildDir%\supersonic-%module%.tar.gz" -C "%runtimeDir%"
|
|
||||||
for /d %%f in ("%runtimeDir%\launchers-%module%-*") do (
|
|
||||||
move "%%f" "%runtimeDir%\supersonic-%module%"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not exist "%runtimeDir%\supersonic-%module%\logs" mkdir "%runtimeDir%\supersonic-%module%\logs"
|
|
||||||
|
|
||||||
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
|
|
||||||
if not exist "%runtimeDir%\supersonic-%module%\webapp" mkdir "%runtimeDir%\supersonic-%module%\webapp"
|
|
||||||
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-%module%\webapp"
|
|
||||||
if not exist "%runtimeDir%\supersonic-%module%\conf\webapp" mkdir "%runtimeDir%\supersonic-%module%\conf\webapp"
|
|
||||||
xcopy /s /e /h /y "%runtimeDir%\supersonic-%module%\webapp\*" "%runtimeDir%\supersonic-%module%\conf\webapp"
|
|
||||||
rd /s /q "%buildDir%\supersonic-webapp"
|
|
||||||
|
|
||||||
::3. start service
|
|
||||||
::start standalone service
|
|
||||||
if "%command%"=="start" (
|
|
||||||
call:START
|
|
||||||
goto :EOF
|
goto :EOF
|
||||||
)
|
) else if "%command%"=="stop" (
|
||||||
|
call :STOP
|
||||||
if "%command%"=="restart" (
|
goto :EOF
|
||||||
call:START
|
) else if "%command%"=="reload" (
|
||||||
|
call :RELOAD_EXAMPLE
|
||||||
|
goto :EOF
|
||||||
|
) else (
|
||||||
|
echo "Use command {start|stop|restart} to run."
|
||||||
goto :EOF
|
goto :EOF
|
||||||
)
|
)
|
||||||
|
|
||||||
:START
|
:START
|
||||||
if "%module%"=="standalone" (
|
if "%service%"=="%pyllm_service%" (
|
||||||
cd "%runtimeDir%"
|
call :START_PYTHON
|
||||||
"%pip_path%" install -r "%llm_path%\requirements.txt"
|
call :START_JAVA
|
||||||
"%python_path%" -c "import langchain,fastapi,chromadb,tiktoken,uvicorn" >nul 2>&1
|
goto :EOF
|
||||||
cd "%runtimeDir%\supersonic-standalone\llm\llm"
|
)
|
||||||
start "" /B uvicorn %start_name%:app --port %llm_port% --host %llm_host% > "%runtimeDir%\supersonic-standalone\llm\llm.log" 2>&1
|
call :START_JAVA
|
||||||
echo "llm service started, see logs/error with logs/error command"
|
|
||||||
)
|
|
||||||
start "supersonic" /B java %java-command%>"%runtimeDir%\supersonic-%module%\logs\info-%module%.log" 2>&1
|
|
||||||
echo "%module% service started, see logs/error with logs/error command"
|
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
|
||||||
:STOP
|
:STOP
|
||||||
|
call :STOP_PYTHON
|
||||||
|
call :STOP_JAVA
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:START_PYTHON
|
||||||
|
echo 'python service starting, see logs in pyllm/pyllm.log'
|
||||||
|
cd "%pythonRunDir%"
|
||||||
|
start /B %python_path% supersonic_pyllm.py > %pythonRunDir%\pyllm.log 2>&1
|
||||||
|
timeout /t 10 >nul
|
||||||
|
echo 'python service started'
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:START_JAVA
|
||||||
|
echo 'java service starting, see logs in logs/'
|
||||||
|
cd "%javaRunDir%"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\logs" mkdir "%runtimeDir%\supersonic-standalone\logs"
|
||||||
|
set "libDir=%runtimeDir%\supersonic-standalone\lib"
|
||||||
|
set "confDir=%runtimeDir%\supersonic-standalone\conf"
|
||||||
|
set "webDir=%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
set "classpath=%confDir%;%webDir%;%libDir%\*"
|
||||||
|
set "java-command=-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08 -Xms1024m -Xmx2048m -cp %CLASSPATH% %MAIN_CLASS%"
|
||||||
|
start /B java %java-command% >nul 2>&1
|
||||||
|
timeout /t 10 >nul
|
||||||
|
echo 'java service started'
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:STOP_PYTHON
|
||||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
|
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "python"') do (
|
||||||
taskkill /PID %%i /F
|
taskkill /PID %%i /F
|
||||||
echo "llm Process (PID = %%i) is killed."
|
echo "python service (PID = %%i) is killed."
|
||||||
)
|
|
||||||
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
|
||||||
taskkill /PID %%i /F
|
|
||||||
echo "%module% Process (PID = %%i) is killed."
|
|
||||||
)
|
)
|
||||||
goto :EOF
|
goto :EOF
|
||||||
|
|
||||||
|
:STOP_JAVA
|
||||||
|
for /f "tokens=2" %%i in ('tasklist ^| findstr /i "java"') do (
|
||||||
|
taskkill /PID %%i /F
|
||||||
|
echo "java service (PID = %%i) is killed."
|
||||||
|
)
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:RELOAD_EXAMPLE
|
||||||
|
cd "%runtimeDir%\supersonic-standalone\pyllm\sql"
|
||||||
|
start %python_path% examples_reload_run.py
|
||||||
|
goto :EOF
|
||||||
|
|
||||||
|
:BUILD_RUNTIME
|
||||||
|
rem 6. reset runtime
|
||||||
|
if exist "%runtimeDir%" goto :EOF
|
||||||
|
mkdir "%runtimeDir%"
|
||||||
|
tar -zxvf "%buildDir%\supersonic-standalone.tar.gz" -C "%runtimeDir%"
|
||||||
|
for /d %%f in ("%runtimeDir%\launchers-standalone-*") do (
|
||||||
|
move "%%f" "%runtimeDir%\supersonic-standalone"
|
||||||
|
)
|
||||||
|
|
||||||
|
rem 7. copy webapp to runtime
|
||||||
|
tar -zxvf "%buildDir%\supersonic-webapp.tar.gz" -C "%buildDir%"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\webapp" mkdir "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
xcopy /s /e /h /y "%buildDir%\supersonic-webapp\*" "%runtimeDir%\supersonic-standalone\webapp"
|
||||||
|
if not exist "%runtimeDir%\supersonic-standalone\conf\webapp" mkdir "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
|
xcopy /s /e /h /y "%runtimeDir%\supersonic-standalone\webapp\*" "%runtimeDir%\supersonic-standalone\conf\webapp"
|
||||||
|
rd /s /q "%buildDir%\supersonic-webapp"
|
||||||
193
assembly/bin/supersonic-daemon.sh
Normal file → Executable file
193
assembly/bin/supersonic-daemon.sh
Normal file → Executable file
@@ -1,98 +1,143 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -x
|
||||||
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
sbinDir=$(cd "$(dirname "$0")"; pwd)
|
||||||
baseDir=$(cd "$sbinDir/.." && pwd -P)
|
chmod +x $sbinDir/supersonic-common.sh
|
||||||
runtimeDir=$baseDir/../runtime
|
source $sbinDir/supersonic-common.sh
|
||||||
buildDir=$baseDir/build
|
|
||||||
|
# 1.init environment parameters
|
||||||
|
if [ ! -d "$runtimeDir" ]; then
|
||||||
|
echo "the runtime dir does not exist move all to runtime"
|
||||||
|
moveAllToRuntime
|
||||||
|
fi
|
||||||
|
set +x
|
||||||
|
|
||||||
command=$1
|
command=$1
|
||||||
service=$2
|
service=$2
|
||||||
|
if [ -z "$service" ]; then
|
||||||
|
service=${STANDALONE_SERVICE}
|
||||||
|
fi
|
||||||
|
|
||||||
|
app_name=$STANDALONE_APP_NAME
|
||||||
|
main_class="com.tencent.supersonic.StandaloneLauncher"
|
||||||
|
model_name=$service
|
||||||
|
|
||||||
|
if [ "$service" == "pyllm" ]; then
|
||||||
|
model_name=${STANDALONE_SERVICE}
|
||||||
|
export llmProxy=PythonLLMProxy
|
||||||
|
fi
|
||||||
|
|
||||||
cd $baseDir
|
cd $baseDir
|
||||||
if [[ "$service" == "semantic" || -z "$service" ]] && [ "$command" != "stop" ]; then
|
|
||||||
#1. clear file
|
|
||||||
mkdir -p ${runtimeDir}
|
|
||||||
rm -fr ${runtimeDir}/*
|
|
||||||
|
|
||||||
#2. package lib
|
# 2.set main class
|
||||||
tar -zxvf ${buildDir}/supersonic.tar.gz -C ${runtimeDir}
|
function setMainClass {
|
||||||
mv ${runtimeDir}/launchers-standalone-* ${runtimeDir}/supersonic-standalone
|
if [ "$service" == $CHAT_SERVICE ]; then
|
||||||
tar -zxvf ${buildDir}/supersonic-webapp.tar.gz -C ${buildDir}
|
main_class="com.tencent.supersonic.ChatLauncher"
|
||||||
mkdir -p ${runtimeDir}/supersonic-standalone/webapp
|
elif [ "$service" == $SEMANTIC_SERVICE ]; then
|
||||||
cp -fr ${buildDir}/supersonic-webapp/* ${runtimeDir}/supersonic-standalone/webapp
|
main_class="com.tencent.supersonic.SemanticLauncher"
|
||||||
rm -fr ${buildDir}/supersonic-webapp
|
fi
|
||||||
fi
|
}
|
||||||
if [[ "$service" == "semantic" ]]; then
|
setMainClass
|
||||||
json=$(cat ${runtimeDir}/supersonic-semantic/webapp/supersonic.config.json)
|
# 3.set app name
|
||||||
json=$(echo $json | jq '.env="semantic"')
|
function setAppName {
|
||||||
echo $json > ${runtimeDir}/supersonic-semantic/webapp/supersonic.config.json
|
if [ "$service" == $CHAT_SERVICE ]; then
|
||||||
fi
|
app_name=$CHAT_APP_NAME
|
||||||
|
elif [ "$service" == $SEMANTIC_SERVICE ]; then
|
||||||
|
app_name=$SEMANTIC_APP_NAME
|
||||||
|
elif [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
|
app_name=$PYLLM_APP_NAME
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
setAppName
|
||||||
|
|
||||||
if [[ "$service" == "chat" ]]; then
|
function reloadExamples {
|
||||||
json=$(cat ${runtimeDir}/supersonic-chat/webapp/supersonic.config.json)
|
pythonRunDir=${runtimeDir}/supersonic-${model_name}/pyllm
|
||||||
json=$(echo $json | jq '.env="chat"')
|
cd $pythonRunDir/sql
|
||||||
echo $json > ${runtimeDir}/supersonic-chat/webapp/supersonic.config.json
|
${python_path} examples_reload_run.py
|
||||||
fi
|
}
|
||||||
echo $command
|
|
||||||
echo $service
|
|
||||||
|
function start()
|
||||||
|
{
|
||||||
|
local_app_name=$1
|
||||||
|
pid=$(ps aux |grep ${local_app_name} | grep -v grep | awk '{print $2}')
|
||||||
|
if [[ "$pid" == "" ]]; then
|
||||||
|
if [[ ${local_app_name} == $PYLLM_APP_NAME ]]; then
|
||||||
|
runPythonService ${local_app_name}
|
||||||
|
else
|
||||||
|
runJavaService ${local_app_name}
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Process (PID = $pid) is running."
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
function stop()
|
||||||
|
{
|
||||||
|
pid=$(ps aux | grep $1 | grep -v grep | awk '{print $2}')
|
||||||
|
if [[ "$pid" == "" ]]; then
|
||||||
|
echo "Process $1 is not running !"
|
||||||
|
return 1
|
||||||
|
else
|
||||||
|
kill -9 $pid
|
||||||
|
echo "Process (PID = $pid) is killed !"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
function reload()
|
||||||
|
{
|
||||||
|
if [[ $1 == $PYLLM_APP_NAME ]]; then
|
||||||
|
reloadExamples
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. execute command operation
|
||||||
case "$command" in
|
case "$command" in
|
||||||
start)
|
start)
|
||||||
if [[ "$service" == "semantic" ]];then
|
if [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
echo -e "Starting semantic"
|
echo "Starting $app_name"
|
||||||
sh ${runtimeDir}/supersonic-semantic/bin/service.sh start
|
start $app_name
|
||||||
elif [[ "$service" == "chat" ]];then
|
echo "Starting $STANDALONE_APP_NAME"
|
||||||
echo -e "Starting chat"
|
start $STANDALONE_APP_NAME
|
||||||
sh ${runtimeDir}/supersonic-chat/bin/service.sh start
|
|
||||||
elif [[ "$service" == "llmparser" ]];then
|
|
||||||
echo -e "Starting LLM"
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh start
|
|
||||||
elif [[ -z "$service" ]]; then
|
|
||||||
echo -e "Starting supersonic"
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/bin/service.sh start
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh start
|
|
||||||
else
|
else
|
||||||
echo "Use command {semantic|semantic||} to run."
|
echo "Starting $app_name"
|
||||||
|
start $app_name
|
||||||
fi
|
fi
|
||||||
|
echo "Start success"
|
||||||
;;
|
;;
|
||||||
stop)
|
stop)
|
||||||
if [[ "$service" == "semantic" ]];then
|
echo "Stopping $app_name"
|
||||||
echo -e "Stopping semantic"
|
stop $app_name
|
||||||
sh ${runtimeDir}/supersonic-semantic/bin/service.sh stop
|
echo "Stopping $PYLLM_APP_NAME"
|
||||||
elif [[ "$service" == "chat" ]];then
|
stop $PYLLM_APP_NAME
|
||||||
echo -e "Stopping chat"
|
echo "Stop success"
|
||||||
sh ${runtimeDir}/supersonic-chat/bin/service.sh stop
|
;;
|
||||||
elif [[ "$service" == "llmparser" ]];then
|
reload)
|
||||||
echo -e "Stopping LLM"
|
echo "Reloading ${app_name}"
|
||||||
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh stop
|
reload ${app_name}
|
||||||
elif [[ -z "$service" ]]; then
|
echo "Reload success"
|
||||||
echo -e "Stopping supersonic"
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/bin/service.sh stop
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh stop
|
|
||||||
else
|
|
||||||
echo "Use command {semantic|semantic||} to run."
|
|
||||||
fi
|
|
||||||
;;
|
;;
|
||||||
restart)
|
restart)
|
||||||
if [[ "$service" == "semantic" ]];then
|
if [ "$service" == $PYLLM_SERVICE ]; then
|
||||||
echo -e "Restarting semantic"
|
echo "Stopping ${app_name}"
|
||||||
sh ${runtimeDir}/supersonic-semantic/bin/service.sh restart
|
stop ${app_name}
|
||||||
elif [[ "$service" == "chat" ]];then
|
echo "Stopping ${STANDALONE_APP_NAME}"
|
||||||
echo -e "Restarting chat"
|
stop $STANDALONE_APP_NAME
|
||||||
sh ${runtimeDir}/supersonic-chat/bin/service.sh restart
|
echo "Starting ${app_name}"
|
||||||
elif [[ "$service" == "llmparser" ]];then
|
start ${app_name}
|
||||||
echo -e "Restarting LLM"
|
echo "Starting ${STANDALONE_APP_NAME}"
|
||||||
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart
|
start $STANDALONE_APP_NAME
|
||||||
elif [[ -z "$service" ]]; then
|
|
||||||
echo -e "Restarting supersonic"
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/bin/service.sh restart
|
|
||||||
sh ${runtimeDir}/supersonic-standalone/llm/bin/service.sh restart
|
|
||||||
else
|
else
|
||||||
echo "Use command {semantic|semantic||} to run."
|
echo "Stopping ${app_name}"
|
||||||
|
stop ${app_name}
|
||||||
|
echo "Starting ${app_name}"
|
||||||
|
start ${app_name}
|
||||||
fi
|
fi
|
||||||
|
echo "Restart success"
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Use command {start|stop|status|restart} to run."
|
echo "Use command {start|stop|restart} to run."
|
||||||
exit 1
|
exit 1
|
||||||
esac
|
esac
|
||||||
|
|
||||||
exit 0
|
|
||||||
|
|||||||
@@ -6,14 +6,6 @@
|
|||||||
<format>tar.gz</format>
|
<format>tar.gz</format>
|
||||||
</formats>
|
</formats>
|
||||||
<fileSets>
|
<fileSets>
|
||||||
|
|
||||||
<fileSet>
|
|
||||||
<directory>${project.basedir}/src/main/bin</directory>
|
|
||||||
<outputDirectory>bin</outputDirectory>
|
|
||||||
<fileMode>0777</fileMode>
|
|
||||||
<directoryMode>0755</directoryMode>
|
|
||||||
</fileSet>
|
|
||||||
|
|
||||||
<fileSet>
|
<fileSet>
|
||||||
<directory>${project.basedir}/src/main/resources</directory>
|
<directory>${project.basedir}/src/main/resources</directory>
|
||||||
<outputDirectory>conf</outputDirectory>
|
<outputDirectory>conf</outputDirectory>
|
||||||
@@ -29,8 +21,8 @@
|
|||||||
</includes>
|
</includes>
|
||||||
</fileSet>
|
</fileSet>
|
||||||
<fileSet>
|
<fileSet>
|
||||||
<directory>${project.basedir}/../../chat/core/src/main/python</directory>
|
<directory>${project.basedir}/../../chat/python</directory>
|
||||||
<outputDirectory>llm</outputDirectory>
|
<outputDirectory>pyllm</outputDirectory>
|
||||||
<fileMode>0777</fileMode>
|
<fileMode>0777</fileMode>
|
||||||
<directoryMode>0755</directoryMode>
|
<directoryMode>0755</directoryMode>
|
||||||
</fileSet>
|
</fileSet>
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* UserAdaptor defines some interfaces for obtaining user and organization information
|
||||||
|
*/
|
||||||
public interface UserAdaptor {
|
public interface UserAdaptor {
|
||||||
|
|
||||||
List<String> getUserNames();
|
List<String> getUserNames();
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ public class UserConstants {
|
|||||||
|
|
||||||
public static final String TOKEN_USER_EMAIL = "token_user_email";
|
public static final String TOKEN_USER_EMAIL = "token_user_email";
|
||||||
|
|
||||||
|
public static final String TOKEN_IS_ADMIN = "token_is_admin";
|
||||||
|
|
||||||
public static final String TOKEN_ALGORITHM = "HS512";
|
public static final String TOKEN_ALGORITHM = "HS512";
|
||||||
|
|
||||||
public static final String TOKEN_CREATE_TIME = "token_create_time";
|
public static final String TOKEN_CREATE_TIME = "token_create_time";
|
||||||
|
|||||||
@@ -18,17 +18,22 @@ public class User {
|
|||||||
|
|
||||||
private String email;
|
private String email;
|
||||||
|
|
||||||
public static User get(Long id, String name, String displayName, String email) {
|
private Integer isAdmin;
|
||||||
return new User(id, name, displayName, email);
|
|
||||||
|
public static User get(Long id, String name, String displayName, String email, Integer isAdmin) {
|
||||||
|
return new User(id, name, displayName, email, isAdmin);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static User getFakeUser() {
|
public static User getFakeUser() {
|
||||||
return new User(1L, "admin", "admin", "admin@email");
|
return new User(1L, "admin", "admin", "admin@email", 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getDisplayName() {
|
public String getDisplayName() {
|
||||||
return StringUtils.isBlank(displayName) ? name : displayName;
|
return StringUtils.isBlank(displayName) ? name : displayName;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isSuperAdmin() {
|
||||||
|
return isAdmin != null && isAdmin == 1;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,13 +9,14 @@ public class UserWithPassword extends User {
|
|||||||
|
|
||||||
private String password;
|
private String password;
|
||||||
|
|
||||||
public UserWithPassword(Long id, String name, String displayName, String email, String password) {
|
public UserWithPassword(Long id, String name, String displayName, String email, String password, Integer isAdmin) {
|
||||||
super(id, name, displayName, email);
|
super(id, name, displayName, email, isAdmin);
|
||||||
this.password = password;
|
this.password = password;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static UserWithPassword get(Long id, String name, String displayName, String email, String password) {
|
public static UserWithPassword get(Long id, String name, String displayName,
|
||||||
return new UserWithPassword(id, name, displayName, email, password);
|
String email, String password, Integer isAdmin) {
|
||||||
|
return new UserWithPassword(id, name, displayName, email, password, isAdmin);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class AuthGroup {
|
public class AuthGroup {
|
||||||
|
|
||||||
private String modelId;
|
private Long modelId;
|
||||||
private String name;
|
private String name;
|
||||||
private Integer groupId;
|
private Integer groupId;
|
||||||
private List<AuthRule> authRules;
|
private List<AuthRule> authRules;
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import lombok.ToString;
|
|||||||
@ToString
|
@ToString
|
||||||
public class AuthRes {
|
public class AuthRes {
|
||||||
|
|
||||||
private String modelId;
|
private Long modelId;
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
public AuthRes() {
|
public AuthRes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public AuthRes(String modelId, String name) {
|
public AuthRes(Long modelId, String name) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package com.tencent.supersonic.auth.api.authorization.request;
|
package com.tencent.supersonic.auth.api.authorization.request;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
@@ -15,5 +17,17 @@ public class QueryAuthResReq {
|
|||||||
|
|
||||||
private List<AuthRes> resources;
|
private List<AuthRes> resources;
|
||||||
|
|
||||||
private String modelId;
|
private Long modelId;
|
||||||
|
|
||||||
|
private List<Long> modelIds;
|
||||||
|
|
||||||
|
public List<Long> getModelIds() {
|
||||||
|
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||||
|
return modelIds;
|
||||||
|
}
|
||||||
|
if (modelId != null) {
|
||||||
|
return Lists.newArrayList(modelId);
|
||||||
|
}
|
||||||
|
return Lists.newArrayList();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ public interface AuthService {
|
|||||||
|
|
||||||
List<AuthGroup> queryAuthGroups(String domainId, Integer groupId);
|
List<AuthGroup> queryAuthGroups(String domainId, Integer groupId);
|
||||||
|
|
||||||
void updateAuthGroup(AuthGroup group);
|
void addOrUpdateAuthGroup(AuthGroup group);
|
||||||
|
|
||||||
void removeAuthGroup(AuthGroup group);
|
void removeAuthGroup(AuthGroup group);
|
||||||
|
|
||||||
|
|||||||
@@ -33,12 +33,6 @@
|
|||||||
<artifactId>spring-boot-starter-jdbc</artifactId>
|
<artifactId>spring-boot-starter-jdbc</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.alibaba</groupId>
|
<groupId>com.alibaba</groupId>
|
||||||
<artifactId>druid</artifactId>
|
<artifactId>druid</artifactId>
|
||||||
@@ -52,12 +46,7 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis-spring</artifactId>
|
|
||||||
<version>${mybatis-spring.version}</version>
|
|
||||||
<scope>compile</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.github.pagehelper</groupId>
|
<groupId>com.github.pagehelper</groupId>
|
||||||
<artifactId>pagehelper</artifactId>
|
<artifactId>pagehelper</artifactId>
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ import java.util.List;
|
|||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DefaultUserAdaptor provides a default method to obtain user and organization information
|
||||||
|
*/
|
||||||
public class DefaultUserAdaptor implements UserAdaptor {
|
public class DefaultUserAdaptor implements UserAdaptor {
|
||||||
|
|
||||||
private List<UserDO> getUserDOList() {
|
private List<UserDO> getUserDOList() {
|
||||||
@@ -71,7 +74,7 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
}
|
}
|
||||||
if (userDO.getPassword().equals(userReq.getPassword())) {
|
if (userDO.getPassword().equals(userReq.getPassword())) {
|
||||||
UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
|
||||||
userDO.getEmail(), userDO.getPassword());
|
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
||||||
return userTokenUtils.generateToken(user);
|
return userTokenUtils.generateToken(user);
|
||||||
}
|
}
|
||||||
throw new RuntimeException("password not correct, please try again");
|
throw new RuntimeException("password not correct, please try again");
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
|||||||
|
|
||||||
protected S2ThreadContext s2ThreadContext;
|
protected S2ThreadContext s2ThreadContext;
|
||||||
|
|
||||||
|
|
||||||
protected boolean isExcludedUri(String uri) {
|
protected boolean isExcludedUri(String uri) {
|
||||||
String excludePathStr = authenticationConfig.getExcludePath();
|
String excludePathStr = authenticationConfig.getExcludePath();
|
||||||
if (Strings.isEmpty(excludePathStr)) {
|
if (Strings.isEmpty(excludePathStr)) {
|
||||||
@@ -59,7 +58,6 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
|
|||||||
return "true".equalsIgnoreCase(internal);
|
return "true".equalsIgnoreCase(internal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
protected void reflectSetparam(HttpServletRequest request, String key, String value) {
|
protected void reflectSetparam(HttpServletRequest request, String key, String value) {
|
||||||
try {
|
try {
|
||||||
if (request instanceof StandardMultipartHttpServletRequest) {
|
if (request instanceof StandardMultipartHttpServletRequest) {
|
||||||
|
|||||||
@@ -76,5 +76,4 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor
|
|||||||
s2ThreadContext.set(threadContext);
|
s2ThreadContext.set(threadContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.auth.authentication.persistence.dataobject;
|
package com.tencent.supersonic.auth.authentication.persistence.dataobject;
|
||||||
|
|
||||||
public class UserDO {
|
public class UserDO {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
@@ -28,6 +27,12 @@ public class UserDO {
|
|||||||
private String email;
|
private String email;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
private Integer isAdmin;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
* @return id
|
* @return id
|
||||||
*/
|
*/
|
||||||
public Long getId() {
|
public Long getId() {
|
||||||
@@ -35,6 +40,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @param id
|
* @param id
|
||||||
*/
|
*/
|
||||||
public void setId(Long id) {
|
public void setId(Long id) {
|
||||||
@@ -42,6 +48,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @return name
|
* @return name
|
||||||
*/
|
*/
|
||||||
public String getName() {
|
public String getName() {
|
||||||
@@ -49,6 +56,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @param name
|
* @param name
|
||||||
*/
|
*/
|
||||||
public void setName(String name) {
|
public void setName(String name) {
|
||||||
@@ -56,6 +64,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @return password
|
* @return password
|
||||||
*/
|
*/
|
||||||
public String getPassword() {
|
public String getPassword() {
|
||||||
@@ -63,6 +72,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @param password
|
* @param password
|
||||||
*/
|
*/
|
||||||
public void setPassword(String password) {
|
public void setPassword(String password) {
|
||||||
@@ -70,6 +80,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @return display_name
|
* @return display_name
|
||||||
*/
|
*/
|
||||||
public String getDisplayName() {
|
public String getDisplayName() {
|
||||||
@@ -77,6 +88,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @param displayName
|
* @param displayName
|
||||||
*/
|
*/
|
||||||
public void setDisplayName(String displayName) {
|
public void setDisplayName(String displayName) {
|
||||||
@@ -84,6 +96,7 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @return email
|
* @return email
|
||||||
*/
|
*/
|
||||||
public String getEmail() {
|
public String getEmail() {
|
||||||
@@ -91,9 +104,26 @@ public class UserDO {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @param email
|
* @param email
|
||||||
*/
|
*/
|
||||||
public void setEmail(String email) {
|
public void setEmail(String email) {
|
||||||
this.email = email == null ? null : email.trim();
|
this.email = email == null ? null : email.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @return is_admin
|
||||||
|
*/
|
||||||
|
public Integer getIsAdmin() {
|
||||||
|
return isAdmin;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param isAdmin
|
||||||
|
*/
|
||||||
|
public void setIsAdmin(Integer isAdmin) {
|
||||||
|
this.isAdmin = isAdmin;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -4,7 +4,6 @@ import java.util.ArrayList;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class UserDOExample {
|
public class UserDOExample {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* s2_user
|
* s2_user
|
||||||
*/
|
*/
|
||||||
@@ -31,6 +30,7 @@ public class UserDOExample {
|
|||||||
protected Integer limitEnd;
|
protected Integer limitEnd;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public UserDOExample() {
|
public UserDOExample() {
|
||||||
@@ -38,13 +38,7 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @mbg.generated
|
*
|
||||||
*/
|
|
||||||
public String getOrderByClause() {
|
|
||||||
return orderByClause;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public void setOrderByClause(String orderByClause) {
|
public void setOrderByClause(String orderByClause) {
|
||||||
@@ -52,13 +46,15 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public boolean isDistinct() {
|
public String getOrderByClause() {
|
||||||
return distinct;
|
return orderByClause;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public void setDistinct(boolean distinct) {
|
public void setDistinct(boolean distinct) {
|
||||||
@@ -66,6 +62,15 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
|
* @mbg.generated
|
||||||
|
*/
|
||||||
|
public boolean isDistinct() {
|
||||||
|
return distinct;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public List<Criteria> getOredCriteria() {
|
public List<Criteria> getOredCriteria() {
|
||||||
@@ -73,6 +78,7 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public void or(Criteria criteria) {
|
public void or(Criteria criteria) {
|
||||||
@@ -80,6 +86,7 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public Criteria or() {
|
public Criteria or() {
|
||||||
@@ -89,6 +96,7 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public Criteria createCriteria() {
|
public Criteria createCriteria() {
|
||||||
@@ -100,6 +108,7 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
protected Criteria createCriteriaInternal() {
|
protected Criteria createCriteriaInternal() {
|
||||||
@@ -108,6 +117,7 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public void clear() {
|
public void clear() {
|
||||||
@@ -117,6 +127,15 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
|
* @mbg.generated
|
||||||
|
*/
|
||||||
|
public void setLimitStart(Integer limitStart) {
|
||||||
|
this.limitStart=limitStart;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public Integer getLimitStart() {
|
public Integer getLimitStart() {
|
||||||
@@ -124,31 +143,25 @@ public class UserDOExample {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public void setLimitStart(Integer limitStart) {
|
public void setLimitEnd(Integer limitEnd) {
|
||||||
this.limitStart = limitStart;
|
this.limitEnd=limitEnd;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
*
|
||||||
* @mbg.generated
|
* @mbg.generated
|
||||||
*/
|
*/
|
||||||
public Integer getLimitEnd() {
|
public Integer getLimitEnd() {
|
||||||
return limitEnd;
|
return limitEnd;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @mbg.generated
|
|
||||||
*/
|
|
||||||
public void setLimitEnd(Integer limitEnd) {
|
|
||||||
this.limitEnd = limitEnd;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* s2_user null
|
* s2_user null
|
||||||
*/
|
*/
|
||||||
protected abstract static class GeneratedCriteria {
|
protected abstract static class GeneratedCriteria {
|
||||||
|
|
||||||
protected List<Criterion> criteria;
|
protected List<Criterion> criteria;
|
||||||
|
|
||||||
protected GeneratedCriteria() {
|
protected GeneratedCriteria() {
|
||||||
@@ -528,6 +541,66 @@ public class UserDOExample {
|
|||||||
addCriterion("email not between", value1, value2, "email");
|
addCriterion("email not between", value1, value2, "email");
|
||||||
return (Criteria) this;
|
return (Criteria) this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminIsNull() {
|
||||||
|
addCriterion("is_admin is null");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminIsNotNull() {
|
||||||
|
addCriterion("is_admin is not null");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminEqualTo(Integer value) {
|
||||||
|
addCriterion("is_admin =", value, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminNotEqualTo(Integer value) {
|
||||||
|
addCriterion("is_admin <>", value, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminGreaterThan(Integer value) {
|
||||||
|
addCriterion("is_admin >", value, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminGreaterThanOrEqualTo(Integer value) {
|
||||||
|
addCriterion("is_admin >=", value, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminLessThan(Integer value) {
|
||||||
|
addCriterion("is_admin <", value, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminLessThanOrEqualTo(Integer value) {
|
||||||
|
addCriterion("is_admin <=", value, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminIn(List<Integer> values) {
|
||||||
|
addCriterion("is_admin in", values, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminNotIn(List<Integer> values) {
|
||||||
|
addCriterion("is_admin not in", values, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminBetween(Integer value1, Integer value2) {
|
||||||
|
addCriterion("is_admin between", value1, value2, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Criteria andIsAdminNotBetween(Integer value1, Integer value2) {
|
||||||
|
addCriterion("is_admin not between", value1, value2, "isAdmin");
|
||||||
|
return (Criteria) this;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -544,7 +617,6 @@ public class UserDOExample {
|
|||||||
* s2_user null
|
* s2_user null
|
||||||
*/
|
*/
|
||||||
public static class Criterion {
|
public static class Criterion {
|
||||||
|
|
||||||
private String condition;
|
private String condition;
|
||||||
|
|
||||||
private Object value;
|
private Object value;
|
||||||
@@ -561,6 +633,38 @@ public class UserDOExample {
|
|||||||
|
|
||||||
private String typeHandler;
|
private String typeHandler;
|
||||||
|
|
||||||
|
public String getCondition() {
|
||||||
|
return condition;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Object getValue() {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Object getSecondValue() {
|
||||||
|
return secondValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isNoValue() {
|
||||||
|
return noValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isSingleValue() {
|
||||||
|
return singleValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isBetweenValue() {
|
||||||
|
return betweenValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isListValue() {
|
||||||
|
return listValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getTypeHandler() {
|
||||||
|
return typeHandler;
|
||||||
|
}
|
||||||
|
|
||||||
protected Criterion(String condition) {
|
protected Criterion(String condition) {
|
||||||
super();
|
super();
|
||||||
this.condition = condition;
|
this.condition = condition;
|
||||||
@@ -596,37 +700,5 @@ public class UserDOExample {
|
|||||||
protected Criterion(String condition, Object value, Object secondValue) {
|
protected Criterion(String condition, Object value, Object secondValue) {
|
||||||
this(condition, value, secondValue, null);
|
this(condition, value, secondValue, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getCondition() {
|
|
||||||
return condition;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Object getValue() {
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Object getSecondValue() {
|
|
||||||
return secondValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isNoValue() {
|
|
||||||
return noValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isSingleValue() {
|
|
||||||
return singleValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isBetweenValue() {
|
|
||||||
return betweenValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isListValue() {
|
|
||||||
return listValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getTypeHandler() {
|
|
||||||
return typeHandler;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -20,7 +20,6 @@ public class UserRepositoryImpl implements UserRepository {
|
|||||||
this.userDOMapper = userDOMapper;
|
this.userDOMapper = userDOMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<UserDO> getUserList() {
|
public List<UserDO> getUserList() {
|
||||||
return userDOMapper.selectByExample(new UserDOExample());
|
return userDOMapper.selectByExample(new UserDOExample());
|
||||||
@@ -40,5 +39,4 @@ public class UserRepositoryImpl implements UserRepository {
|
|||||||
return userDOOptional.orElse(null);
|
return userDOOptional.orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ public class UserController {
|
|||||||
this.userService = userService;
|
this.userService = userService;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@GetMapping("/getCurrentUser")
|
@GetMapping("/getCurrentUser")
|
||||||
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||||
return UserHolder.findUser(httpServletRequest, httpServletResponse);
|
return UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||||
@@ -70,5 +69,4 @@ public class UserController {
|
|||||||
return userService.login(userCmd);
|
return userService.login(userCmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import org.springframework.stereotype.Service;
|
|||||||
@Service
|
@Service
|
||||||
public class UserServiceImpl implements UserService {
|
public class UserServiceImpl implements UserService {
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<String> getUserNames() {
|
public List<String> getUserNames() {
|
||||||
return ComponentFactory.getUserAdaptor().getUserNames();
|
return ComponentFactory.getUserAdaptor().getUserNames();
|
||||||
|
|||||||
@@ -20,5 +20,4 @@ public class FakeUserStrategy implements UserStrategy {
|
|||||||
return User.getFakeUser();
|
return User.getFakeUser();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.auth.authentication.utils;
|
|||||||
|
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_ALGORITHM;
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_ALGORITHM;
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME;
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME;
|
||||||
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN;
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_PREFIX;
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_PREFIX;
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_TIME_OUT;
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_TIME_OUT;
|
||||||
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME;
|
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME;
|
||||||
@@ -42,6 +43,7 @@ public class UserTokenUtils {
|
|||||||
claims.put(TOKEN_USER_PASSWORD, StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword());
|
claims.put(TOKEN_USER_PASSWORD, StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword());
|
||||||
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
|
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
|
||||||
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
|
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
|
||||||
|
claims.put(TOKEN_IS_ADMIN, user.getIsAdmin());
|
||||||
return generate(claims);
|
return generate(claims);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,10 +54,10 @@ public class UserTokenUtils {
|
|||||||
claims.put(TOKEN_USER_PASSWORD, "admin");
|
claims.put(TOKEN_USER_PASSWORD, "admin");
|
||||||
claims.put(TOKEN_USER_DISPLAY_NAME, "admin");
|
claims.put(TOKEN_USER_DISPLAY_NAME, "admin");
|
||||||
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
|
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
|
||||||
|
claims.put(TOKEN_IS_ADMIN, 1);
|
||||||
return generate(claims);
|
return generate(claims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public User getUser(HttpServletRequest request) {
|
public User getUser(HttpServletRequest request) {
|
||||||
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
|
||||||
final Claims claims = getClaims(token);
|
final Claims claims = getClaims(token);
|
||||||
@@ -63,7 +65,9 @@ public class UserTokenUtils {
|
|||||||
String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
|
String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
|
||||||
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
|
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
|
||||||
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
|
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
|
||||||
return User.get(userId, userName, displayName, email);
|
Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null
|
||||||
|
? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
|
||||||
|
return User.get(userId, userName, displayName, email, isAdmin);
|
||||||
}
|
}
|
||||||
|
|
||||||
public UserWithPassword getUserWithPassword(HttpServletRequest request) {
|
public UserWithPassword getUserWithPassword(HttpServletRequest request) {
|
||||||
@@ -79,7 +83,9 @@ public class UserTokenUtils {
|
|||||||
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
|
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
|
||||||
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
|
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
|
||||||
String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD));
|
String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD));
|
||||||
return UserWithPassword.get(userId, userName, displayName, email, password);
|
Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null
|
||||||
|
? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
|
||||||
|
return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Claims getClaims(String token) {
|
private Claims getClaims(String token) {
|
||||||
@@ -113,5 +119,4 @@ public class UserTokenUtils {
|
|||||||
.compact();
|
.compact();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,12 @@
|
|||||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||||
<mapper namespace="com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper">
|
<mapper namespace="com.tencent.supersonic.auth.authentication.persistence.mapper.UserDOMapper">
|
||||||
<resultMap id="BaseResultMap" type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
<resultMap id="BaseResultMap" type="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||||
<id column="id" jdbcType="BIGINT" property="id" />
|
<result column="id" jdbcType="BIGINT" property="id" />
|
||||||
<result column="name" jdbcType="VARCHAR" property="name" />
|
<result column="name" jdbcType="VARCHAR" property="name" />
|
||||||
<result column="password" jdbcType="VARCHAR" property="password" />
|
<result column="password" jdbcType="VARCHAR" property="password" />
|
||||||
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
|
<result column="display_name" jdbcType="VARCHAR" property="displayName" />
|
||||||
<result column="email" jdbcType="VARCHAR" property="email" />
|
<result column="email" jdbcType="VARCHAR" property="email" />
|
||||||
|
<result column="is_admin" jdbcType="INTEGER" property="isAdmin" />
|
||||||
</resultMap>
|
</resultMap>
|
||||||
<sql id="Example_Where_Clause">
|
<sql id="Example_Where_Clause">
|
||||||
<where>
|
<where>
|
||||||
@@ -38,7 +39,7 @@
|
|||||||
</where>
|
</where>
|
||||||
</sql>
|
</sql>
|
||||||
<sql id="Base_Column_List">
|
<sql id="Base_Column_List">
|
||||||
id, name, password, display_name, email
|
id, name, password, display_name, email, is_admin
|
||||||
</sql>
|
</sql>
|
||||||
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
|
<select id="selectByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultMap="BaseResultMap">
|
||||||
select
|
select
|
||||||
@@ -57,21 +58,13 @@
|
|||||||
limit #{limitStart} , #{limitEnd}
|
limit #{limitStart} , #{limitEnd}
|
||||||
</if>
|
</if>
|
||||||
</select>
|
</select>
|
||||||
<select id="selectByPrimaryKey" parameterType="java.lang.Long" resultMap="BaseResultMap">
|
|
||||||
select
|
|
||||||
<include refid="Base_Column_List" />
|
|
||||||
from s2_user
|
|
||||||
where id = #{id,jdbcType=BIGINT}
|
|
||||||
</select>
|
|
||||||
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
|
|
||||||
delete from s2_user
|
|
||||||
where id = #{id,jdbcType=BIGINT}
|
|
||||||
</delete>
|
|
||||||
<insert id="insert" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
<insert id="insert" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||||
insert into s2_user (id, name, password,
|
insert into s2_user (id, name, password,
|
||||||
display_name, email)
|
display_name, email, is_admin
|
||||||
|
)
|
||||||
values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR},
|
values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR},
|
||||||
#{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR})
|
#{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}, #{isAdmin,jdbcType=INTEGER}
|
||||||
|
)
|
||||||
</insert>
|
</insert>
|
||||||
<insert id="insertSelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
<insert id="insertSelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
||||||
insert into s2_user
|
insert into s2_user
|
||||||
@@ -91,6 +84,9 @@
|
|||||||
<if test="email != null">
|
<if test="email != null">
|
||||||
email,
|
email,
|
||||||
</if>
|
</if>
|
||||||
|
<if test="isAdmin != null">
|
||||||
|
is_admin,
|
||||||
|
</if>
|
||||||
</trim>
|
</trim>
|
||||||
<trim prefix="values (" suffix=")" suffixOverrides=",">
|
<trim prefix="values (" suffix=")" suffixOverrides=",">
|
||||||
<if test="id != null">
|
<if test="id != null">
|
||||||
@@ -108,6 +104,9 @@
|
|||||||
<if test="email != null">
|
<if test="email != null">
|
||||||
#{email,jdbcType=VARCHAR},
|
#{email,jdbcType=VARCHAR},
|
||||||
</if>
|
</if>
|
||||||
|
<if test="isAdmin != null">
|
||||||
|
#{isAdmin,jdbcType=INTEGER},
|
||||||
|
</if>
|
||||||
</trim>
|
</trim>
|
||||||
</insert>
|
</insert>
|
||||||
<select id="countByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultType="java.lang.Long">
|
<select id="countByExample" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDOExample" resultType="java.lang.Long">
|
||||||
@@ -116,30 +115,4 @@
|
|||||||
<include refid="Example_Where_Clause" />
|
<include refid="Example_Where_Clause" />
|
||||||
</if>
|
</if>
|
||||||
</select>
|
</select>
|
||||||
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
|
||||||
update s2_user
|
|
||||||
<set>
|
|
||||||
<if test="name != null">
|
|
||||||
name = #{name,jdbcType=VARCHAR},
|
|
||||||
</if>
|
|
||||||
<if test="password != null">
|
|
||||||
password = #{password,jdbcType=VARCHAR},
|
|
||||||
</if>
|
|
||||||
<if test="displayName != null">
|
|
||||||
display_name = #{displayName,jdbcType=VARCHAR},
|
|
||||||
</if>
|
|
||||||
<if test="email != null">
|
|
||||||
email = #{email,jdbcType=VARCHAR},
|
|
||||||
</if>
|
|
||||||
</set>
|
|
||||||
where id = #{id,jdbcType=BIGINT}
|
|
||||||
</update>
|
|
||||||
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.auth.authentication.persistence.dataobject.UserDO">
|
|
||||||
update s2_user
|
|
||||||
set name = #{name,jdbcType=VARCHAR},
|
|
||||||
password = #{password,jdbcType=VARCHAR},
|
|
||||||
display_name = #{displayName,jdbcType=VARCHAR},
|
|
||||||
email = #{email,jdbcType=VARCHAR}
|
|
||||||
where id = #{id,jdbcType=BIGINT}
|
|
||||||
</update>
|
|
||||||
</mapper>
|
</mapper>
|
||||||
@@ -40,7 +40,7 @@ public class AuthController {
|
|||||||
@PostMapping("/createGroup")
|
@PostMapping("/createGroup")
|
||||||
public void newAuthGroup(@RequestBody AuthGroup group) {
|
public void newAuthGroup(@RequestBody AuthGroup group) {
|
||||||
group.setGroupId(null);
|
group.setGroupId(null);
|
||||||
authService.updateAuthGroup(group);
|
authService.addOrUpdateAuthGroup(group);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/removeGroup")
|
@PostMapping("/removeGroup")
|
||||||
@@ -58,7 +58,7 @@ public class AuthController {
|
|||||||
if (group.getGroupId() == null || group.getGroupId() == 0) {
|
if (group.getGroupId() == null || group.getGroupId() == 0) {
|
||||||
throw new RuntimeException("groupId is empty");
|
throw new RuntimeException("groupId is empty");
|
||||||
}
|
}
|
||||||
authService.updateAuthGroup(group);
|
authService.addOrUpdateAuthGroup(group);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.auth.authorization.application;
|
package com.tencent.supersonic.auth.authorization.service;
|
||||||
|
|
||||||
import com.google.common.base.Strings;
|
import com.google.common.base.Strings;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||||
@@ -13,7 +14,6 @@ import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
|||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.jdbc.core.JdbcTemplate;
|
import org.springframework.jdbc.core.JdbcTemplate;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -48,12 +48,12 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
|
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
|
||||||
return load().stream()
|
return load().stream()
|
||||||
.filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
|
.filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
|
||||||
&& modelId.equals(group.getModelId()))
|
&& modelId.equals(group.getModelId().toString()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void updateAuthGroup(AuthGroup group) {
|
public void addOrUpdateAuthGroup(AuthGroup group) {
|
||||||
Gson g = new Gson();
|
Gson g = new Gson();
|
||||||
if (group.getGroupId() == null) {
|
if (group.getGroupId() == null) {
|
||||||
int nextGroupId = 1;
|
int nextGroupId = 1;
|
||||||
@@ -76,21 +76,17 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
jdbcTemplate.update("delete from s2_auth_groups where group_id = ?", group.getGroupId());
|
jdbcTemplate.update("delete from s2_auth_groups where group_id = ?", group.getGroupId());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||||
if (!CollectionUtils.isEmpty(userOrgIds)) {
|
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
|
||||||
req.setDepartmentIds(new ArrayList<>(userOrgIds));
|
|
||||||
}
|
|
||||||
List<AuthGroup> groups = getAuthGroups(req, user.getName());
|
|
||||||
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
||||||
Map<String, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||||
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
||||||
Map<String, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
Map<Long, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
||||||
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
||||||
|
|
||||||
for (String modelId : reqAuthRes.keySet()) {
|
for (Long modelId : reqAuthRes.keySet()) {
|
||||||
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
||||||
AuthResGrp rg = new AuthResGrp();
|
AuthResGrp rg = new AuthResGrp();
|
||||||
if (authGroupsByModelId.containsKey(modelId)) {
|
if (authGroupsByModelId.containsKey(modelId)) {
|
||||||
@@ -113,8 +109,11 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(req.getModelId())) {
|
if (!CollectionUtils.isEmpty(req.getModelIds())) {
|
||||||
List<AuthGroup> authGroups = authGroupsByModelId.get(req.getModelId());
|
List<AuthGroup> authGroups = Lists.newArrayList();
|
||||||
|
for (Long modelId : authGroupsByModelId.keySet()) {
|
||||||
|
authGroups.addAll(authGroupsByModelId.getOrDefault(modelId, Lists.newArrayList()));
|
||||||
|
}
|
||||||
if (!CollectionUtils.isEmpty(authGroups)) {
|
if (!CollectionUtils.isEmpty(authGroups)) {
|
||||||
for (AuthGroup group : authGroups) {
|
for (AuthGroup group : authGroups) {
|
||||||
if (group.getDimensionFilters() != null
|
if (group.getDimensionFilters() != null
|
||||||
@@ -130,17 +129,17 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
return resource;
|
return resource;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<AuthGroup> getAuthGroups(QueryAuthResReq req, String userName) {
|
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
|
||||||
List<AuthGroup> groups = load().stream()
|
List<AuthGroup> groups = load().stream()
|
||||||
.filter(group -> {
|
.filter(group -> {
|
||||||
if (!Objects.equals(group.getModelId(), req.getModelId())) {
|
if (CollectionUtils.isEmpty(modelIds) || !modelIds.contains(group.getModelId())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
||||||
.contains(userName)) {
|
.contains(userName)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
for (String departmentId : req.getDepartmentIds()) {
|
for (String departmentId : departmentIds) {
|
||||||
if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
|
if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
|
||||||
&& group.getAuthorizedDepartmentIds().contains(departmentId)) {
|
&& group.getAuthorizedDepartmentIds().contains(departmentId)) {
|
||||||
return true;
|
return true;
|
||||||
@@ -148,7 +147,7 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups);
|
log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups);
|
||||||
return groups;
|
return groups;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.component;
|
package com.tencent.supersonic.chat.api.component;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic corrector checks validity of extracted semantic information and
|
* A semantic corrector checks validity of extracted semantic information and
|
||||||
@@ -9,5 +9,5 @@ import net.sf.jsqlparser.JSQLParserException;
|
|||||||
*/
|
*/
|
||||||
public interface SemanticCorrector {
|
public interface SemanticCorrector {
|
||||||
|
|
||||||
void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException;
|
void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,10 +8,14 @@ import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
|
|||||||
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||||
|
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
|
||||||
|
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
|
|
||||||
@@ -28,16 +32,32 @@ import java.util.List;
|
|||||||
* as proxy to a remote semantic service.
|
* as proxy to a remote semantic service.
|
||||||
* </p>
|
* </p>
|
||||||
*/
|
*/
|
||||||
public interface SemanticLayer {
|
public interface SemanticInterpreter {
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
|
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||||
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
|
|
||||||
|
QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);
|
||||||
|
|
||||||
|
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||||
|
|
||||||
List<ModelSchema> getModelSchema();
|
List<ModelSchema> getModelSchema();
|
||||||
|
|
||||||
List<ModelSchema> getModelSchema(List<Long> ids);
|
List<ModelSchema> getModelSchema(List<Long> ids);
|
||||||
|
|
||||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
||||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
|
|
||||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd);
|
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||||
|
|
||||||
|
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
||||||
|
|
||||||
List<DomainResp> getDomainList(User user);
|
List<DomainResp> getDomainList(User user);
|
||||||
|
|
||||||
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
||||||
|
|
||||||
|
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||||
|
|
||||||
|
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -14,6 +14,10 @@ public interface SemanticQuery {
|
|||||||
|
|
||||||
QueryResult execute(User user) throws SqlParseException;
|
QueryResult execute(User user) throws SqlParseException;
|
||||||
|
|
||||||
|
void initS2Sql(User user);
|
||||||
|
|
||||||
|
String explain(User user);
|
||||||
|
|
||||||
SemanticParseInfo getParseInfo();
|
SemanticParseInfo getParseInfo();
|
||||||
|
|
||||||
void setParseInfo(SemanticParseInfo parseInfo);
|
void setParseInfo(SemanticParseInfo parseInfo);
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@@ -13,7 +18,9 @@ public class ModelSchema {
|
|||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||||
|
private Set<SchemaElement> tags = new HashSet<>();
|
||||||
private SchemaElement entity = new SchemaElement();
|
private SchemaElement entity = new SchemaElement();
|
||||||
|
private List<ModelRela> modelRelas = new ArrayList<>();
|
||||||
|
|
||||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
Optional<SchemaElement> element = Optional.empty();
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
@@ -34,6 +41,9 @@ public class ModelSchema {
|
|||||||
case VALUE:
|
case VALUE:
|
||||||
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
|
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
break;
|
break;
|
||||||
|
case TAG:
|
||||||
|
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,4 +54,45 @@ public class ModelSchema {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElement(SchemaElementType elementType, String name) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = Optional.ofNullable(entity);
|
||||||
|
break;
|
||||||
|
case MODEL:
|
||||||
|
element = Optional.of(model);
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<Long> getModelClusterSet() {
|
||||||
|
if (CollectionUtils.isEmpty(modelRelas)) {
|
||||||
|
return Sets.newHashSet();
|
||||||
|
}
|
||||||
|
Set<Long> modelClusterSet = new HashSet<>();
|
||||||
|
modelRelas.forEach(modelRela -> {
|
||||||
|
modelClusterSet.add(modelRela.getToModelId());
|
||||||
|
modelClusterSet.add(modelRela.getFromModelId());
|
||||||
|
});
|
||||||
|
return modelClusterSet;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ public class QueryContext {
|
|||||||
private QueryReq request;
|
private QueryReq request;
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
|
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||||
|
|
||||||
public QueryContext(QueryReq request) {
|
public QueryContext(QueryReq request) {
|
||||||
this.request = request;
|
this.request = request;
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class RelatedSchemaElement {
|
||||||
|
|
||||||
|
private Long dimensionId;
|
||||||
|
|
||||||
|
private boolean isNecessary;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,30 +1,35 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Getter
|
@Getter
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class SchemaElement implements Serializable {
|
public class SchemaElement implements Serializable {
|
||||||
|
|
||||||
private Long model;
|
private Long model;
|
||||||
private Long id;
|
private Long id;
|
||||||
private String name;
|
private String name;
|
||||||
private String bizName;
|
private String bizName;
|
||||||
private Long useCnt;
|
private Long useCnt;
|
||||||
private SchemaElementType type;
|
private SchemaElementType type;
|
||||||
|
|
||||||
private List<String> alias;
|
private List<String> alias;
|
||||||
|
|
||||||
private List<SchemaValueMap> schemaValueMaps;
|
private List<SchemaValueMap> schemaValueMaps;
|
||||||
|
private List<RelatedSchemaElement> relatedSchemaElements;
|
||||||
|
|
||||||
|
private String defaultAgg;
|
||||||
|
|
||||||
|
private double order;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
@@ -37,13 +42,13 @@ public class SchemaElement implements Serializable {
|
|||||||
SchemaElement schemaElement = (SchemaElement) o;
|
SchemaElement schemaElement = (SchemaElement) o;
|
||||||
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
|
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
|
||||||
schemaElement.id) && Objects.equal(name, schemaElement.name)
|
schemaElement.id) && Objects.equal(name, schemaElement.name)
|
||||||
&& Objects.equal(bizName, schemaElement.bizName) && Objects.equal(
|
&& Objects.equal(bizName, schemaElement.bizName)
|
||||||
useCnt, schemaElement.useCnt) && Objects.equal(type, schemaElement.type);
|
&& Objects.equal(type, schemaElement.type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hashCode(model, id, name, bizName, useCnt, type);
|
return Objects.hashCode(model, id, name, bizName, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ public enum SchemaElementType {
|
|||||||
DIMENSION,
|
DIMENSION,
|
||||||
VALUE,
|
VALUE,
|
||||||
ENTITY,
|
ENTITY,
|
||||||
|
TAG,
|
||||||
ID,
|
ID,
|
||||||
DATE
|
DATE
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SchemaModelClusterMapInfo {
|
||||||
|
|
||||||
|
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||||
|
|
||||||
|
public Set<String> getMatchedModelClusters() {
|
||||||
|
return modelElementMatches.keySet();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
||||||
|
for (String key : modelElementMatches.keySet()) {
|
||||||
|
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||||
|
return modelElementMatches.get(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Lists.newArrayList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
||||||
|
return modelElementMatches.get(modelCluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
||||||
|
return modelElementMatches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
||||||
|
if (CollectionUtils.isEmpty(modelIds)) {
|
||||||
|
return modelElementMatches;
|
||||||
|
}
|
||||||
|
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
||||||
|
for (String key : modelElementMatches.keySet()) {
|
||||||
|
for (Long modelId : modelIds) {
|
||||||
|
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||||
|
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelElementMatchesFiltered;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
||||||
|
this.modelElementMatches = modelElementMatches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
||||||
|
modelElementMatches.put(modelCluster, elementMatches);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,53 +1,74 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.TreeSet;
|
|
||||||
import java.util.LinkedHashSet;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.Comparator;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.TreeSet;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class SemanticParseInfo {
|
public class SemanticParseInfo {
|
||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode;
|
private String queryMode;
|
||||||
private SchemaElement model;
|
private ModelCluster model = new ModelCluster();
|
||||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||||
private SchemaElement entity;
|
private SchemaElement entity;
|
||||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||||
|
private FilterType filterType = FilterType.UNION;
|
||||||
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
|
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
|
||||||
private Set<QueryFilter> metricFilters = new LinkedHashSet();
|
private Set<QueryFilter> metricFilters = new LinkedHashSet();
|
||||||
private Set<Order> orders = new LinkedHashSet();
|
private Set<Order> orders = new LinkedHashSet();
|
||||||
private DateConf dateInfo;
|
private DateConf dateInfo;
|
||||||
private Long limit;
|
private Long limit;
|
||||||
private Boolean nativeQuery = false;
|
|
||||||
private double score;
|
private double score;
|
||||||
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
||||||
private Map<String, Object> properties = new HashMap<>();
|
private Map<String, Object> properties = new HashMap<>();
|
||||||
private EntityInfo entityInfo;
|
private EntityInfo entityInfo;
|
||||||
public Long getModelId() {
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
return model != null ? model.getId() : 0L;
|
private QueryType queryType = QueryType.ID;
|
||||||
|
|
||||||
|
public String getModelClusterKey() {
|
||||||
|
if (model == null) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return model.getKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getModelName() {
|
public String getModelName() {
|
||||||
return model != null ? model.getName() : "null";
|
if (model == null) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return model.getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int compare(SchemaElement o1, SchemaElement o2) {
|
public int compare(SchemaElement o1, SchemaElement o2) {
|
||||||
|
if (o1.getOrder() != o2.getOrder()) {
|
||||||
|
if (o1.getOrder() < o2.getOrder()) {
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
int len1 = o1.getName().length();
|
int len1 = o1.getName().length();
|
||||||
int len2 = o2.getName().length();
|
int len2 = o2.getName().length();
|
||||||
if (len1 != len2) {
|
if (len1 != len2) {
|
||||||
@@ -65,4 +86,27 @@ public class SemanticParseInfo {
|
|||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Map<Long, Integer> getModelElementCountMap() {
|
||||||
|
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||||
|
elementMatches.stream().filter(element -> element.getElement().getModel() != null)
|
||||||
|
.forEach(element -> {
|
||||||
|
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||||
|
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||||
|
});
|
||||||
|
return elementCountMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long getModelId() {
|
||||||
|
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
||||||
|
Long modelId = -1L;
|
||||||
|
int maxCnt = 0;
|
||||||
|
for (Long model : elementCountMap.keySet()) {
|
||||||
|
if (elementCountMap.get(model) > maxCnt) {
|
||||||
|
maxCnt = elementCountMap.get(model);
|
||||||
|
modelId = model;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelId;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class SemanticSchema implements Serializable {
|
public class SemanticSchema implements Serializable {
|
||||||
|
|
||||||
private List<ModelSchema> modelSchemaList;
|
private List<ModelSchema> modelSchemaList;
|
||||||
|
|
||||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
||||||
@@ -17,6 +23,64 @@ public class SemanticSchema implements Serializable {
|
|||||||
modelSchemaList.add(schema);
|
modelSchemaList.add(schema);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = getElementsById(elementID, getEntities());
|
||||||
|
break;
|
||||||
|
case MODEL:
|
||||||
|
element = getElementsById(elementID, getModels());
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = getElementsById(elementID, getMetrics());
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = getElementsById(elementID, getDimensions());
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = getElementsById(elementID, getDimensionValues());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = getElementsByNameOrAlias(name, getEntities());
|
||||||
|
break;
|
||||||
|
case MODEL:
|
||||||
|
element = getElementsByNameOrAlias(name, getModels());
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = getElementsByNameOrAlias(name, getMetrics());
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = getElementsByNameOrAlias(name, getDimensions());
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = getElementsByNameOrAlias(name, getDimensionValues());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public Map<Long, String> getModelIdToName() {
|
public Map<Long, String> getModelIdToName() {
|
||||||
return modelSchemaList.stream()
|
return modelSchemaList.stream()
|
||||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
||||||
@@ -34,21 +98,85 @@ public class SemanticSchema implements Serializable {
|
|||||||
return dimensions;
|
return dimensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
||||||
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
|
return getElementsByModelId(modelIds, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getDimensions(Long id) {
|
||||||
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
|
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||||
|
return dimension.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTags() {
|
||||||
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
|
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||||
|
return tags;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||||
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
|
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
|
.forEach(d -> tags.addAll(d.getTags()));
|
||||||
|
return tags;
|
||||||
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics() {
|
public List<SchemaElement> getMetrics() {
|
||||||
List<SchemaElement> metrics = new ArrayList<>();
|
List<SchemaElement> metrics = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
||||||
|
List<SchemaElement> metrics = getMetrics();
|
||||||
|
return getElementsByModelId(modelIds, metrics);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getEntities() {
|
||||||
|
List<SchemaElement> entities = new ArrayList<>();
|
||||||
|
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||||
|
return entities;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> id.equals(schemaElement.getId()))
|
||||||
|
.findFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Optional<SchemaElement> getElementsByNameOrAlias(String name, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement ->
|
||||||
|
name.equals(schemaElement.getName()) || schemaElement.getAlias().contains(name)
|
||||||
|
).findFirst();
|
||||||
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getModels() {
|
public List<SchemaElement> getModels() {
|
||||||
List<SchemaElement> models = new ArrayList<>();
|
List<SchemaElement> models = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||||
return models;
|
return models;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getEntities() {
|
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
||||||
List<SchemaElement> entities = new ArrayList<>();
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
allElements.addAll(getDimensions(modelIds));
|
||||||
return entities;
|
allElements.addAll(getMetrics(modelIds));
|
||||||
|
return allElements.stream()
|
||||||
|
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Long, ModelSchema> getModelSchemaMap() {
|
||||||
|
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
||||||
|
return new HashMap<>();
|
||||||
|
}
|
||||||
|
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||||
|
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ public class ChatConfigBaseReq {
|
|||||||
*/
|
*/
|
||||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* the llm examples about the model
|
||||||
|
*/
|
||||||
|
private String llmExamples;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* available status
|
* available status
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import javax.validation.constraints.NotNull;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static java.time.LocalDate.now;
|
||||||
|
|
||||||
|
@ToString
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class DictLatestTaskReq {
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
private Long modelId;
|
||||||
|
|
||||||
|
private List<Long> dimIds;
|
||||||
|
|
||||||
|
private String createdAt = now().plusDays(-4).toString();
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
|
@ToString
|
||||||
|
@Data
|
||||||
|
public class DictTaskFilterReq {
|
||||||
|
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String createdBy;
|
||||||
|
|
||||||
|
private String createdAt;
|
||||||
|
|
||||||
|
private TaskStatusEnum status;
|
||||||
|
}
|
||||||
@@ -1,12 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
import javax.validation.constraints.NotNull;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DimensionValueReq {
|
public class DimensionValueReq {
|
||||||
|
|
||||||
|
private Integer agentId;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
private Long elementID;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
|
|
||||||
private String bizName;
|
private String bizName;
|
||||||
|
|
||||||
private Object value;
|
@NotNull
|
||||||
|
private String value;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,16 +3,18 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
|||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Builder
|
||||||
@Data
|
@Data
|
||||||
public class ExecuteQueryReq {
|
public class ExecuteQueryReq {
|
||||||
private User user;
|
private User user;
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Long queryId = 7L;
|
private Long queryId;
|
||||||
private Integer parseId = 2;
|
private Integer parseId;
|
||||||
private SemanticParseInfo parseInfo;
|
private SemanticParseInfo parseInfo;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class PageQueryInfoReq {
|
public class PageQueryInfoReq {
|
||||||
@@ -11,27 +12,9 @@ public class PageQueryInfoReq {
|
|||||||
|
|
||||||
private String userName;
|
private String userName;
|
||||||
|
|
||||||
public int getPageSize() {
|
private List<Long> ids;
|
||||||
return pageSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setPageSize(int pageSize) {
|
public Integer getLimitStart() {
|
||||||
this.pageSize = pageSize;
|
return this.pageSize * (this.current - 1);
|
||||||
}
|
|
||||||
|
|
||||||
public int getCurrent() {
|
|
||||||
return current;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setCurrent(int current) {
|
|
||||||
this.current = current;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getUserName() {
|
|
||||||
return userName;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setUserName(String userName) {
|
|
||||||
this.userName = userName;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,25 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryDataReq {
|
public class QueryDataReq {
|
||||||
String queryMode;
|
private User user;
|
||||||
SchemaElement model;
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<QueryFilter> dimensionFilters = new HashSet<>();
|
||||||
Set<QueryFilter> dimensionFilters = new HashSet<>();
|
private Set<QueryFilter> metricFilters = new HashSet<>();
|
||||||
Set<QueryFilter> metricFilters = new HashSet<>();
|
|
||||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
|
||||||
private Set<Order> orders = new HashSet<>();
|
|
||||||
private DateConf dateInfo;
|
private DateConf dateInfo;
|
||||||
private Long limit;
|
private Long queryId;
|
||||||
private Boolean nativeQuery = false;
|
private Integer parseId;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
import com.google.common.base.Objects;
|
||||||
import java.util.Objects;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
@@ -19,6 +19,8 @@ public class QueryFilter {
|
|||||||
|
|
||||||
private Long elementID;
|
private Long elementID;
|
||||||
|
|
||||||
|
private String function;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) {
|
if (this == o) {
|
||||||
@@ -27,14 +29,15 @@ public class QueryFilter {
|
|||||||
if (o == null || getClass() != o.getClass()) {
|
if (o == null || getClass() != o.getClass()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
QueryFilter filter = (QueryFilter) o;
|
QueryFilter that = (QueryFilter) o;
|
||||||
return Objects.equals(bizName, filter.bizName) && Objects.equals(name, filter.name)
|
return Objects.equal(bizName, that.bizName) && Objects.equal(name,
|
||||||
&& operator == filter.operator && Objects.equals(value, filter.value) && Objects.equals(
|
that.name) && operator == that.operator && Objects.equal(value, that.value)
|
||||||
elementID, filter.elementID);
|
&& Objects.equal(elementID, that.elementID) && Objects.equal(
|
||||||
|
function, that.function);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(bizName, name, operator, value, elementID);
|
return Objects.hashCode(bizName, name, operator, value, elementID, function);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
|||||||
public class QueryReq {
|
public class QueryReq {
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private Long modelId = 0L;
|
private Long modelId;
|
||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class RecommendReq {
|
||||||
|
|
||||||
|
private Long modelId;
|
||||||
|
|
||||||
|
private Long metricId;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
@Builder
|
||||||
|
public class SimilarQueryReq {
|
||||||
|
|
||||||
|
private Long queryId;
|
||||||
|
|
||||||
|
private Integer parseId;
|
||||||
|
|
||||||
|
private String queryText;
|
||||||
|
|
||||||
|
private String modelId;
|
||||||
|
|
||||||
|
private Integer agentId;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -23,6 +23,8 @@ public class ChatConfigResp {
|
|||||||
|
|
||||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||||
|
|
||||||
|
private String llmExamples;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* available status
|
* available status
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import java.util.Date;
|
||||||
|
|
||||||
|
@ToString
|
||||||
|
@Data
|
||||||
|
public class DictLatestTaskResp {
|
||||||
|
|
||||||
|
private Long dimId;
|
||||||
|
|
||||||
|
private Long id;
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
private String command;
|
||||||
|
|
||||||
|
private TaskStatusEnum status;
|
||||||
|
|
||||||
|
private String createdBy;
|
||||||
|
|
||||||
|
private Date createdAt;
|
||||||
|
|
||||||
|
private Long elapsedMs;
|
||||||
|
}
|
||||||
@@ -1,12 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelInfo extends DataInfo implements Serializable {
|
public class ModelInfo extends DataInfo implements Serializable {
|
||||||
|
|
||||||
private List<String> words;
|
private List<String> words;
|
||||||
private String primaryEntityBizName;
|
private String primaryKey;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,30 +1,23 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Getter
|
|
||||||
@Builder
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public class ParseResp {
|
public class ParseResp {
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private ParseState state;
|
private ParseState state;
|
||||||
private List<SemanticParseInfo> selectedParses;
|
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||||
private List<SemanticParseInfo> candidateParses;
|
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
|
||||||
|
|
||||||
public enum ParseState {
|
public enum ParseState {
|
||||||
COMPLETED,
|
COMPLETED,
|
||||||
PENDING,
|
PENDING,
|
||||||
FAILED
|
FAILED
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class ParseTimeCostDO {
|
||||||
|
|
||||||
|
private long parseStartTime;
|
||||||
|
private long parseTime;
|
||||||
|
private long sqlTime;
|
||||||
|
|
||||||
|
public ParseTimeCostDO() {
|
||||||
|
this.parseStartTime = System.currentTimeMillis();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class QueryRecallResp {
|
||||||
|
private List<SimilarQueryRecallResp> solvedQueryRecallRespList;
|
||||||
|
private Long queryTimeCost;
|
||||||
|
}
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import java.util.Date;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import java.util.Date;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryResp {
|
public class QueryResp {
|
||||||
@@ -13,4 +15,6 @@ public class QueryResp {
|
|||||||
private String feedback;
|
private String feedback;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private QueryResult queryResult;
|
private QueryResult queryResult;
|
||||||
|
private List<SemanticParseInfo> parseInfos;
|
||||||
|
private List<SimilarQueryRecallResp> similarQueries;
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
|
import lombok.Data;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryResult {
|
public class QueryResult {
|
||||||
@@ -21,4 +22,6 @@ public class QueryResult {
|
|||||||
private SemanticParseInfo chatContext;
|
private SemanticParseInfo chatContext;
|
||||||
private Object response;
|
private Object response;
|
||||||
private List<Map<String, Object>> queryResults;
|
private List<Map<String, Object>> queryResults;
|
||||||
|
private Long queryTimeCost;
|
||||||
|
private List<SchemaElement> recommendedDimensions;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class SimilarQueryRecallResp {
|
||||||
|
|
||||||
|
private Long queryId;
|
||||||
|
|
||||||
|
private Integer parseId;
|
||||||
|
|
||||||
|
private String queryText;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SqlInfo {
|
||||||
|
|
||||||
|
private String s2SQL;
|
||||||
|
private String correctS2SQL;
|
||||||
|
private String querySQL;
|
||||||
|
}
|
||||||
@@ -40,11 +40,6 @@
|
|||||||
<scope>compile</scope>
|
<scope>compile</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.junit.jupiter</groupId>
|
|
||||||
<artifactId>junit-jupiter</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-test</artifactId>
|
<artifactId>spring-boot-starter-test</artifactId>
|
||||||
@@ -59,16 +54,7 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis.spring.boot</groupId>
|
|
||||||
<artifactId>mybatis-spring-boot-starter</artifactId>
|
|
||||||
<version>${mybatis-spring.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.alibaba</groupId>
|
<groupId>com.alibaba</groupId>
|
||||||
<artifactId>druid</artifactId>
|
<artifactId>druid</artifactId>
|
||||||
@@ -78,24 +64,6 @@
|
|||||||
<groupId>mysql</groupId>
|
<groupId>mysql</groupId>
|
||||||
<artifactId>mysql-connector-java</artifactId>
|
<artifactId>mysql-connector-java</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis-spring</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis.spring.boot</groupId>
|
|
||||||
<artifactId>mybatis-spring-boot-starter-test</artifactId>
|
|
||||||
<version>${mybatis.test.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.github.pagehelper</groupId>
|
|
||||||
<artifactId>pagehelper-spring-boot-starter</artifactId>
|
|
||||||
<version>${pagehelper.spring.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.h2database</groupId>
|
<groupId>com.h2database</groupId>
|
||||||
@@ -116,7 +84,6 @@
|
|||||||
<groupId>com.tencent.supersonic</groupId>
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
<artifactId>semantic-query</artifactId>
|
<artifactId>semantic-query</artifactId>
|
||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.tencent.supersonic</groupId>
|
<groupId>com.tencent.supersonic</groupId>
|
||||||
@@ -124,12 +91,6 @@
|
|||||||
<version>${project.version}</version>
|
<version>${project.version}</version>
|
||||||
<scope>compile</scope>
|
<scope>compile</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>com.tencent.supersonic</groupId>
|
|
||||||
<artifactId>semantic-query</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>compile</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.github.xkzhangsan</groupId>
|
<groupId>com.github.xkzhangsan</groupId>
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.agent;
|
|||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.agent;
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.agent.tool.AgentTool;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent.tool;
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
|
public enum AgentToolType {
|
||||||
|
NL2SQL_RULE,
|
||||||
|
NL2SQL_LLM,
|
||||||
|
PLUGIN,
|
||||||
|
ANALYTICS
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class DataAnalyticsTool extends AgentTool {
|
||||||
|
|
||||||
|
private Long modelId;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class LLMParserTool extends NL2SQLTool {
|
||||||
|
|
||||||
|
private List<String> exampleQuestions;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class NL2SQLTool extends AgentTool {
|
||||||
|
|
||||||
|
protected List<Long> modelIds;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent.tool;
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.agent.tool;
|
package com.tencent.supersonic.chat.agent;
|
||||||
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -7,12 +7,13 @@ import org.apache.commons.collections.CollectionUtils;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class RuleQueryTool extends AgentTool {
|
public class RuleParserTool extends NL2SQLTool {
|
||||||
|
|
||||||
private List<Long> modelIds;
|
|
||||||
|
|
||||||
private List<String> queryModes;
|
private List<String> queryModes;
|
||||||
|
|
||||||
|
private List<String> queryTypes;
|
||||||
|
|
||||||
public boolean isContainsAllModel() {
|
public boolean isContainsAllModel() {
|
||||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
||||||
}
|
}
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.agent.tool;
|
|
||||||
|
|
||||||
public enum AgentToolType {
|
|
||||||
RULE,
|
|
||||||
DSL,
|
|
||||||
PLUGIN,
|
|
||||||
INTERPRET
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.agent.tool;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DslTool extends AgentTool {
|
|
||||||
|
|
||||||
private List<Long> modelIds;
|
|
||||||
|
|
||||||
private List<String> exampleQuestions;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.agent.tool;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class MetricInterpretTool extends AgentTool {
|
|
||||||
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
private List<MetricOption> metricOptions;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -7,13 +7,19 @@ import org.springframework.context.annotation.Configuration;
|
|||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@Data
|
@Data
|
||||||
public class LLMConfig {
|
public class LLMParserConfig {
|
||||||
|
|
||||||
|
|
||||||
@Value("${llm.url:}")
|
@Value("${llm.parser.url:}")
|
||||||
private String url;
|
private String url;
|
||||||
|
|
||||||
@Value("${query2sql.path:/query2sql}")
|
@Value("${query2sql.path:/query2sql}")
|
||||||
private String queryToSqlPath;
|
private String queryToSqlPath;
|
||||||
|
|
||||||
|
@Value("${dimension.topn:5}")
|
||||||
|
private Integer dimensionTopN;
|
||||||
|
|
||||||
|
@Value("${metric.topn:5}")
|
||||||
|
private Integer metricTopN;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,43 +1,171 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.config;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.context.annotation.PropertySource;
|
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@Data
|
@Data
|
||||||
@PropertySource("classpath:optimization.properties")
|
@Slf4j
|
||||||
//@ComponentScan(basePackages = "com.tencent.supersonic.chat")
|
|
||||||
public class OptimizationConfig {
|
public class OptimizationConfig {
|
||||||
|
|
||||||
@Value("${one.detection.size}")
|
@Value("${one.detection.size:8}")
|
||||||
private Integer oneDetectionSize;
|
private Integer oneDetectionSize;
|
||||||
@Value("${one.detection.max.size}")
|
|
||||||
|
@Value("${one.detection.max.size:20}")
|
||||||
private Integer oneDetectionMaxSize;
|
private Integer oneDetectionMaxSize;
|
||||||
|
|
||||||
@Value("${metric.dimension.min.threshold}")
|
@Value("${metric.dimension.min.threshold:0.3}")
|
||||||
private Double metricDimensionMinThresholdConfig;
|
private Double metricDimensionMinThresholdConfig;
|
||||||
|
|
||||||
@Value("${metric.dimension.threshold}")
|
@Value("${metric.dimension.threshold:0.3}")
|
||||||
private Double metricDimensionThresholdConfig;
|
private Double metricDimensionThresholdConfig;
|
||||||
|
|
||||||
@Value("${dimension.value.threshold}")
|
@Value("${dimension.value.threshold:0.5}")
|
||||||
private Double dimensionValueThresholdConfig;
|
private Double dimensionValueThresholdConfig;
|
||||||
|
|
||||||
@Value("${function.bonus.threshold}")
|
@Value("${long.text.threshold:0.8}")
|
||||||
private Double functionBonusThreshold;
|
|
||||||
|
|
||||||
@Value("${long.text.threshold}")
|
|
||||||
private Double longTextThreshold;
|
private Double longTextThreshold;
|
||||||
|
|
||||||
@Value("${short.text.threshold}")
|
@Value("${short.text.threshold:0.5}")
|
||||||
private Double shortTextThreshold;
|
private Double shortTextThreshold;
|
||||||
|
|
||||||
@Value("${query.text.length.threshold}")
|
@Value("${query.text.length.threshold:10}")
|
||||||
private Integer queryTextLengthThreshold;
|
private Integer queryTextLengthThreshold;
|
||||||
|
@Value("${embedding.mapper.word.min:4}")
|
||||||
|
private int embeddingMapperWordMin;
|
||||||
|
|
||||||
@Value("${candidate.threshold}")
|
@Value("${embedding.mapper.word.max:5}")
|
||||||
private Double candidateThreshold;
|
private int embeddingMapperWordMax;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.batch:50}")
|
||||||
|
private int embeddingMapperBatch;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.number:5}")
|
||||||
|
private int embeddingMapperNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.distance.threshold:0.58}")
|
||||||
|
private Double embeddingMapperDistanceThreshold;
|
||||||
|
|
||||||
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
|
private boolean useLinkingValueSwitch;
|
||||||
|
|
||||||
|
@Value("${s2SQL.generation:TWO_PASS_AUTO_COT}")
|
||||||
|
private SqlGenerationMode sqlGenerationMode;
|
||||||
|
|
||||||
|
@Value("${s2SQL.use.switch:true}")
|
||||||
|
private boolean useS2SqlSwitch;
|
||||||
|
|
||||||
|
@Value("${text2sql.example.num:10}")
|
||||||
|
private int text2sqlExampleNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.fewShots.num:5}")
|
||||||
|
private int text2sqlFewShotsNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.self.consistency.num:2}")
|
||||||
|
private int text2sqlSelfConsistencyNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
||||||
|
private String text2sqlCollectionName;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SysParameterService sysParameterService;
|
||||||
|
|
||||||
|
public Integer getOneDetectionSize() {
|
||||||
|
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getOneDetectionMaxSize() {
|
||||||
|
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMetricDimensionMinThresholdConfig() {
|
||||||
|
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMetricDimensionThresholdConfig() {
|
||||||
|
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getDimensionValueThresholdConfig() {
|
||||||
|
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getLongTextThreshold() {
|
||||||
|
return convertValue("long.text.threshold", Double.class, longTextThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getShortTextThreshold() {
|
||||||
|
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getQueryTextLengthThreshold() {
|
||||||
|
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isUseS2SqlSwitch() {
|
||||||
|
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperWordMin() {
|
||||||
|
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperWordMax() {
|
||||||
|
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperBatch() {
|
||||||
|
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperNumber() {
|
||||||
|
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperRoundNumber() {
|
||||||
|
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getEmbeddingMapperDistanceThreshold() {
|
||||||
|
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isUseLinkingValueSwitch() {
|
||||||
|
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SqlGenerationMode getSqlGenerationMode() {
|
||||||
|
return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode);
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||||
|
try {
|
||||||
|
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||||
|
if (StringUtils.isBlank(value)) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
if (targetType == Double.class) {
|
||||||
|
return targetType.cast(Double.parseDouble(value));
|
||||||
|
} else if (targetType == Integer.class) {
|
||||||
|
return targetType.cast(Integer.parseInt(value));
|
||||||
|
} else if (targetType == Boolean.class) {
|
||||||
|
return targetType.cast(Boolean.parseBoolean(value));
|
||||||
|
} else if (targetType == SqlGenerationMode.class) {
|
||||||
|
return targetType.cast(SqlGenerationMode.valueOf(value));
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("convertValue", e);
|
||||||
|
}
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,20 +2,51 @@ package com.tencent.supersonic.chat.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* basic semantic correction functionality, offering common methods and an
|
||||||
|
* abstract method called doCorrect
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||||
public static final String DATE_FIELD = "数据日期";
|
|
||||||
protected Map<String, String> getFieldToBizName(Long modelId) {
|
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
try {
|
||||||
|
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
doCorrect(queryReq, semanticParseInfo);
|
||||||
|
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
|
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
@@ -23,11 +54,86 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||||
|
|
||||||
|
// support fieldName and field alias
|
||||||
Map<String, String> result = dbAllFields.stream()
|
Map<String, String> result = dbAllFields.stream()
|
||||||
.filter(entry -> entry.getModel().equals(modelId))
|
.filter(entry -> modelIds.contains(entry.getModel()))
|
||||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
|
.flatMap(schemaElement -> {
|
||||||
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
|
||||||
|
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
|
||||||
|
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
|
||||||
|
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
|
||||||
|
|
||||||
|
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
|
||||||
|
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
|
||||||
|
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||||
|
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||||
|
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
|
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
|
|
||||||
|
// If there is no aggregate function in the S2SQL statement and
|
||||||
|
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||||
|
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||||
|
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||||
|
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
needAddFields.addAll(timeFields);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
needAddFields.removeAll(selectFields);
|
||||||
|
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
||||||
|
//add aggregate to all metric
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
|
||||||
|
List<SchemaElement> metrics = getMetricElements(modelIds);
|
||||||
|
|
||||||
|
Map<String, String> metricToAggregate = metrics.stream()
|
||||||
|
.map(schemaElement -> {
|
||||||
|
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
||||||
|
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||||
|
}
|
||||||
|
return schemaElement;
|
||||||
|
}).flatMap(schemaElement -> {
|
||||||
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream().map(element -> Pair.of(element, schemaElement.getDefaultAgg())
|
||||||
|
);
|
||||||
|
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
return semanticSchema.getMetrics(modelIds);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class DateFieldCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
|
|
||||||
String sql = semanticCorrectInfo.getSql();
|
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
|
|
||||||
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
|
||||||
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
|
|
||||||
}
|
|
||||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FieldCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceFields(preSql,
|
|
||||||
getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FieldNameCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
|
|
||||||
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
|
||||||
if (Objects.isNull(context)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
|
||||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
|
||||||
List<ElementValue> linking = llmReq.getLinking();
|
|
||||||
if (CollectionUtils.isEmpty(linking)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
|
||||||
Collectors.groupingBy(ElementValue::getFieldValue,
|
|
||||||
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
|
||||||
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.logging.log4j.util.Strings;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FieldValueCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
|
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
|
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
|
||||||
return new HashMap<>();
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, Map<String, String>> result = new HashMap<>();
|
|
||||||
|
|
||||||
for (SchemaElement dimension : dimensions) {
|
|
||||||
if (Objects.isNull(dimension)
|
|
||||||
|| Strings.isEmpty(dimension.getBizName())
|
|
||||||
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String bizName = dimension.getBizName();
|
|
||||||
|
|
||||||
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
|
|
||||||
|
|
||||||
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
|
|
||||||
if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (Strings.isNotEmpty(valueMap.getBizName())) {
|
|
||||||
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
|
|
||||||
}
|
|
||||||
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
|
|
||||||
valueMap.getAlias().stream().forEach(alias -> {
|
|
||||||
if (Strings.isNotEmpty(alias)) {
|
|
||||||
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
|
|
||||||
result.put(bizName, aliasAndBizNameToTechName);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "From" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class FromCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
String modelName = semanticParseInfo.getModel().getName();
|
||||||
|
SqlParserReplaceHelper.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FunctionAliasCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
|
|
||||||
semanticCorrectInfo.setSql(replaceAlias);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FunctionCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
addGroupByFields(semanticParseInfo);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
|
||||||
|
//add dimension group by
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
//add alias field name
|
||||||
|
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
||||||
|
.flatMap(
|
||||||
|
schemaElement -> {
|
||||||
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
}
|
||||||
|
).collect(Collectors.toSet());
|
||||||
|
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// if only date in select not add group by.
|
||||||
|
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||||
|
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
|
Set<String> groupByFields = selectFields.stream()
|
||||||
|
.filter(field -> dimensions.contains(field))
|
||||||
|
.filter(field -> {
|
||||||
|
if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
})
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||||
|
|
||||||
|
addAggregate(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addAggregate(SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||||
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
|
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
addAggregateToMetric(semanticParseInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class HavingCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
//add aggregate to all metric
|
||||||
|
addHaving(semanticParseInfo);
|
||||||
|
|
||||||
|
//add having expression filed to select
|
||||||
|
addHavingToSelect(semanticParseInfo);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
|
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
||||||
|
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||||
|
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||||
|
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class QueryFilterAppend extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
|
|
||||||
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
|
||||||
log.info("add queryFilter to preSql :{}", queryFilter);
|
|
||||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
|
||||||
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getQueryFilter(QueryFilters queryFilters) {
|
|
||||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return queryFilters.getFilters().stream()
|
|
||||||
.map(filter -> {
|
|
||||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
|
||||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
|
||||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
|
||||||
return bizNameWrap + operatorWrap + valueWrap;
|
|
||||||
})
|
|
||||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.ParseResult;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform schema corrections on the Schema information in S2QL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
correctAggFunction(semanticParseInfo);
|
||||||
|
|
||||||
|
replaceAlias(semanticParseInfo);
|
||||||
|
|
||||||
|
updateFieldNameByLinkingValue(semanticParseInfo);
|
||||||
|
|
||||||
|
updateFieldValueByLinkingValue(semanticParseInfo);
|
||||||
|
|
||||||
|
correctFieldName(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||||
|
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||||
|
if (CollectionUtils.isEmpty(linking)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||||
|
Collectors.groupingBy(ElementValue::getFieldValue,
|
||||||
|
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
||||||
|
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
|
||||||
|
if (Objects.isNull(context)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
|
||||||
|
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return parseResult.getLinkingValues();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||||
|
if (CollectionUtils.isEmpty(linking)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
||||||
|
Collectors.groupingBy(ElementValue::getFieldName,
|
||||||
|
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
|
||||||
|
oldValue -> oldValue,
|
||||||
|
newValue -> newValue,
|
||||||
|
(existingValue, newValue) -> newValue)
|
||||||
|
)));
|
||||||
|
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Select" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SelectCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||||
|
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||||
|
&& !CollectionUtils.isEmpty(selectFields)
|
||||||
|
&& aggregateFields.size() == selectFields.size()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Set;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));
|
|
||||||
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql));
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql));
|
|
||||||
whereFields.removeAll(selectFields);
|
|
||||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
|
||||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
|
||||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
|
||||||
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields));
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
semanticCorrectInfo.setSql(replaceFields);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class TableNameCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
public static final String TABLE_PREFIX = "t_";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModelId();
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,156 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.S2SqlDateHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.logging.log4j.util.Strings;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class WhereCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
addDateIfNotExist(semanticParseInfo);
|
||||||
|
|
||||||
|
parserDateDiffFunction(semanticParseInfo);
|
||||||
|
|
||||||
|
addQueryFilter(queryReq, semanticParseInfo);
|
||||||
|
|
||||||
|
updateFieldValueByTechName(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
||||||
|
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|
||||||
|
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||||
|
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
|
||||||
|
Expression expression = null;
|
||||||
|
try {
|
||||||
|
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("parseCondExpression", e);
|
||||||
|
}
|
||||||
|
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||||
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
|
String currentDate = S2SqlDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||||
|
if (StringUtils.isNotBlank(currentDate)) {
|
||||||
|
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||||
|
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getQueryFilter(QueryFilters queryFilters) {
|
||||||
|
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return queryFilters.getFilters().stream()
|
||||||
|
.map(filter -> {
|
||||||
|
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
|
||||||
|
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||||
|
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||||
|
return bizNameWrap + operatorWrap + valueWrap;
|
||||||
|
})
|
||||||
|
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||||
|
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||||
|
aliasAndBizNameToTechName);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||||
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
|
return new HashMap<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Map<String, String>> result = new HashMap<>();
|
||||||
|
|
||||||
|
for (SchemaElement dimension : dimensions) {
|
||||||
|
if (Objects.isNull(dimension)
|
||||||
|
|| Strings.isEmpty(dimension.getName())
|
||||||
|
|| CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String name = dimension.getName();
|
||||||
|
|
||||||
|
Map<String, String> aliasAndBizNameToTechName = new HashMap<>();
|
||||||
|
|
||||||
|
for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
|
||||||
|
if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (Strings.isNotEmpty(valueMap.getBizName())) {
|
||||||
|
aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
|
||||||
|
valueMap.getAlias().stream().forEach(alias -> {
|
||||||
|
if (Strings.isNotEmpty(alias)) {
|
||||||
|
aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
|
||||||
|
result.put(name, aliasAndBizNameToTechName);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BaseMapper implements SchemaMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void map(QueryContext queryContext) {
|
||||||
|
|
||||||
|
String simpleName = this.getClass().getSimpleName();
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||||
|
|
||||||
|
try {
|
||||||
|
doMap(queryContext);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("work error", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
|
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract void doMap(QueryContext queryContext);
|
||||||
|
|
||||||
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||||
|
if (schemaElementMatches == null) {
|
||||||
|
schemaElementMatches = modelElementMatches.get(modelId);
|
||||||
|
}
|
||||||
|
//remove duplication
|
||||||
|
AtomicBoolean needAddNew = new AtomicBoolean(true);
|
||||||
|
schemaElementMatches.removeIf(
|
||||||
|
existElementMatch -> {
|
||||||
|
SchemaElement existElement = existElementMatch.getElement();
|
||||||
|
SchemaElement newElement = newElementMatch.getElement();
|
||||||
|
if (existElement.equals(newElement)) {
|
||||||
|
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
needAddNew.set(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (needAddNew.get()) {
|
||||||
|
schemaElementMatches.add(newElementMatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
||||||
|
SchemaElement element = new SchemaElement();
|
||||||
|
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||||
|
if (Objects.isNull(modelSchema)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||||
|
if (Objects.isNull(elementDb)) {
|
||||||
|
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
BeanUtils.copyProperties(elementDb, element);
|
||||||
|
element.setAlias(getAlias(elementDb));
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getAlias(SchemaElement element) {
|
||||||
|
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||||
|
element.getName())) {
|
||||||
|
return element.getAlias().stream()
|
||||||
|
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
String text = queryContext.getRequest().getQueryText();
|
||||||
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||||
|
|
||||||
|
List<T> detects = detect(queryContext, terms, detectModelIds);
|
||||||
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
|
String text = queryContext.getRequest().getQueryText();
|
||||||
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
|
Set<String> detectSegments = new HashSet<>();
|
||||||
|
|
||||||
|
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||||
|
|
||||||
|
for (Integer index = startIndex; index <= text.length(); ) {
|
||||||
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
|
if (index <= text.length()) {
|
||||||
|
String detectSegment = text.substring(startIndex, index);
|
||||||
|
detectSegments.add(detectSegment);
|
||||||
|
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
|
}
|
||||||
|
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
||||||
|
return new ArrayList<>(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
||||||
|
Set<String> detectSegments) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
||||||
|
return terms.stream().sorted(Comparator.comparing(Term::length))
|
||||||
|
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||||
|
if (CollectionUtils.isEmpty(oneRoundResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (T oneRoundResult : oneRoundResults) {
|
||||||
|
if (existResults.contains(oneRoundResult)) {
|
||||||
|
boolean isDeleted = existResults.removeIf(
|
||||||
|
existResult -> {
|
||||||
|
boolean delete = needDelete(oneRoundResult, existResult);
|
||||||
|
if (delete) {
|
||||||
|
log.info("deleted existResult:{}", existResult);
|
||||||
|
}
|
||||||
|
return delete;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (isDeleted) {
|
||||||
|
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
|
||||||
|
existResults.add(oneRoundResult);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
existResults.add(oneRoundResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||||
|
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||||
|
terms = filterByModelIds(terms, detectModelIds);
|
||||||
|
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||||
|
List<T> matches = new ArrayList<>();
|
||||||
|
if (Objects.isNull(matchResult)) {
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||||
|
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||||
|
.map(entry -> entry.getValue()).findFirst();
|
||||||
|
|
||||||
|
if (first.isPresent()) {
|
||||||
|
matches = first.get();
|
||||||
|
}
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
logTerms(terms);
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
|
terms = terms.stream().filter(term -> {
|
||||||
|
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||||
|
if (Objects.nonNull(modelId)) {
|
||||||
|
return detectModelIds.contains(modelId);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
|
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||||
|
logTerms(terms);
|
||||||
|
}
|
||||||
|
return terms;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void logTerms(List<Term> terms) {
|
||||||
|
if (CollectionUtils.isEmpty(terms)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (Term term : terms) {
|
||||||
|
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
||||||
|
|
||||||
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
|
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||||
|
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,130 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.DatabaseMapResult;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||||
|
* It currently supports fuzzy matching against names and aliases.
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
@Autowired
|
||||||
|
private SchemaService schemaService;
|
||||||
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||||
|
Set<Long> detectModelIds) {
|
||||||
|
this.allElements = getSchemaElements();
|
||||||
|
return super.match(queryContext, terms, detectModelIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(DatabaseMapResult oneRoundResult, DatabaseMapResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(DatabaseMapResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||||
|
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
||||||
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||||
|
|
||||||
|
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||||
|
|
||||||
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
|
||||||
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
|
String name = entry.getKey();
|
||||||
|
if (!name.contains(detectSegment)
|
||||||
|
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
|
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||||
|
schemaElements = schemaElements.stream()
|
||||||
|
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
|
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
|
||||||
|
databaseMapResult.setDetectWord(detectSegment);
|
||||||
|
databaseMapResult.setName(schemaElement.getName());
|
||||||
|
databaseMapResult.setSchemaElement(schemaElement);
|
||||||
|
existResults.add(databaseMapResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getSchemaElements() {
|
||||||
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
|
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
|
||||||
|
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
|
||||||
|
return allElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Double getThreshold(QueryContext queryContext) {
|
||||||
|
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||||
|
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||||
|
|
||||||
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||||
|
|
||||||
|
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||||
|
|
||||||
|
if (!existElement) {
|
||||||
|
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||||
|
|
||||||
|
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||||
|
: metricDimensionMinThresholdConfig;
|
||||||
|
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||||
|
modelElementMatches, metricDimensionThresholdConfig);
|
||||||
|
}
|
||||||
|
return metricDimensionThresholdConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||||
|
return models.stream().collect(
|
||||||
|
Collectors.toMap(SchemaElement::getName, a -> {
|
||||||
|
Set<SchemaElement> result = new HashSet<>();
|
||||||
|
result.add(a);
|
||||||
|
return result;
|
||||||
|
}, (k1, k2) -> {
|
||||||
|
k1.addAll(k2);
|
||||||
|
return k1;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/***
|
||||||
|
* A mapper that recognizes schema elements with vector embedding.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingMapper extends BaseMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doMap(QueryContext queryContext) {
|
||||||
|
//1. query from embedding by queryText
|
||||||
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
|
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||||
|
|
||||||
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||||
|
|
||||||
|
HanlpHelper.transLetterOriginal(matchResults);
|
||||||
|
|
||||||
|
//2. build SchemaElementMatch by info
|
||||||
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
|
|
||||||
|
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||||
|
SchemaElement.class);
|
||||||
|
|
||||||
|
String modelIdStr = matchResult.getMetadata().get("modelId");
|
||||||
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
long modelId = Long.parseLong(modelIdStr);
|
||||||
|
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
||||||
|
if (schemaElement == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
|
.element(schemaElement)
|
||||||
|
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||||
|
.word(matchResult.getName())
|
||||||
|
.similarity(1 - matchResult.getDistance())
|
||||||
|
.detectWord(matchResult.getDetectWord())
|
||||||
|
.build();
|
||||||
|
//3. add to mapInfo
|
||||||
|
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* EmbeddingMatchStrategy uses vector database to perform
|
||||||
|
* similarity search against the embeddings of schema elements.
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDistance() > oneRoundResult.getDistance();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(EmbeddingResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||||
|
Set<String> detectSegments) {
|
||||||
|
|
||||||
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
|
.map(detectSegment -> detectSegment.trim())
|
||||||
|
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||||
|
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
|
||||||
|
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||||
|
optimizationConfig.getEmbeddingMapperBatch());
|
||||||
|
|
||||||
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
|
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||||
|
List<String> queryTextsSub) {
|
||||||
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
|
Map<String, String> filterCondition = null;
|
||||||
|
// step1. build query params
|
||||||
|
// if only one modelId, add to filterCondition
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||||
|
filterCondition = new HashMap<>();
|
||||||
|
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
|
.queryTextsList(queryTextsSub)
|
||||||
|
.filterCondition(filterCondition)
|
||||||
|
.queryEmbeddings(null)
|
||||||
|
.build();
|
||||||
|
// step2. retrieveQuery by detectSegment
|
||||||
|
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||||
|
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// step3. build EmbeddingResults. filter by modelId
|
||||||
|
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||||
|
.map(retrieveQueryResult -> {
|
||||||
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
|
retrievals.removeIf(retrieval -> {
|
||||||
|
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
||||||
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retrieveQueryResult;
|
||||||
|
})
|
||||||
|
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||||
|
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||||
|
.map(retrieval -> {
|
||||||
|
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||||
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
|
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||||
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
|
return embeddingResult;
|
||||||
|
}))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
// step4. select mapResul in one round
|
||||||
|
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
|
||||||
|
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||||
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||||
|
.limit(roundNumber)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user