mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
Compare commits
78 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
532a00518c | ||
|
|
895f38b6f7 | ||
|
|
eba3a8ad34 | ||
|
|
6813582ea0 | ||
|
|
6b6e54e95f | ||
|
|
26260c79f1 | ||
|
|
0a42932db2 | ||
|
|
3b4678d682 | ||
|
|
c0458ccf0e | ||
|
|
c7c70208ff | ||
|
|
3a38200448 | ||
|
|
b72e280990 | ||
|
|
0541614dad | ||
|
|
0beb3cefd3 | ||
|
|
afdf18398c | ||
|
|
bdf7df933b | ||
|
|
a909493414 | ||
|
|
aa86fc9275 | ||
|
|
e36060eae4 | ||
|
|
3893e897cb | ||
|
|
617cd87a48 | ||
|
|
042a610231 | ||
|
|
b555beae21 | ||
|
|
ba01cdb9bc | ||
|
|
e610dd8246 | ||
|
|
01bc4dcacf | ||
|
|
c224b81160 | ||
|
|
bfd0e040da | ||
|
|
f50a3157d5 | ||
|
|
61316e939c | ||
|
|
fab1bac50c | ||
|
|
d8043c356f | ||
|
|
e95a528219 | ||
|
|
16643e8d75 | ||
|
|
417a43dee8 | ||
|
|
4a22fdf452 | ||
|
|
16afbc6945 | ||
|
|
fc5ff01eca | ||
|
|
d10801ef38 | ||
|
|
33240cc382 | ||
|
|
3317f1b7ec | ||
|
|
b85778babd | ||
|
|
699a33b1c1 | ||
|
|
fdb69547e6 | ||
|
|
39158d6877 | ||
|
|
329ad327b0 | ||
|
|
9600456bae | ||
|
|
74d0ec2b23 | ||
|
|
8a342eb32a | ||
|
|
e801c448be | ||
|
|
da5e7b9b75 | ||
|
|
75853a8e9e | ||
|
|
2546d1c0e1 | ||
|
|
0c4c6d83ef | ||
|
|
4d4922d269 | ||
|
|
1004f71ba4 | ||
|
|
c13a0e672c | ||
|
|
491c76368c | ||
|
|
2c1c443b3e | ||
|
|
f29b1854ba | ||
|
|
7f15bacca4 | ||
|
|
df975b231d | ||
|
|
24b442baef | ||
|
|
31f8c1df35 | ||
|
|
26aefceb04 | ||
|
|
954c67c947 | ||
|
|
fdfad515dd | ||
|
|
c398ac1a84 | ||
|
|
aae3d6b297 | ||
|
|
923c65b2f9 | ||
|
|
22775343f4 | ||
|
|
d9533c53ea | ||
|
|
841db25198 | ||
|
|
922201c181 | ||
|
|
48fb01f6bc | ||
|
|
9d6f96e6d4 | ||
|
|
42a6f61456 | ||
|
|
163e782f51 |
35
.github/workflows/mac-ci.yml
vendored
Normal file
35
.github/workflows/mac-ci.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
name: supersonic mac CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: macos-latest # Specify a macOS runner
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up JDK 8
|
||||||
|
uses: actions/setup-java@v2
|
||||||
|
with:
|
||||||
|
java-version: '8'
|
||||||
|
distribution: 'adopt'
|
||||||
|
|
||||||
|
- name: Cache Maven packages
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~/Library/Caches/Maven # macOS Maven cache path
|
||||||
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
|
restore-keys: ${{ runner.os }}-m2
|
||||||
|
|
||||||
|
- name: Build with Maven
|
||||||
|
run: mvn -B package --file pom.xml
|
||||||
|
|
||||||
|
- name: Test with Maven
|
||||||
|
run: mvn test
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
name: supersonic CI
|
name: supersonic ubuntu CI
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
35
.github/workflows/windows-ci.yml
vendored
Normal file
35
.github/workflows/windows-ci.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
name: supersonic windows CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: windows-latest # Specify a Windows runner
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up JDK 8
|
||||||
|
uses: actions/setup-java@v2
|
||||||
|
with:
|
||||||
|
java-version: '8'
|
||||||
|
distribution: 'adopt'
|
||||||
|
|
||||||
|
- name: Cache Maven packages
|
||||||
|
uses: actions/cache@v2
|
||||||
|
with:
|
||||||
|
path: ~\.m2 # Windows uses a backslash for paths
|
||||||
|
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||||
|
restore-keys: ${{ runner.os }}-m2
|
||||||
|
|
||||||
|
- name: Build with Maven
|
||||||
|
run: mvn -B package --file pom.xml
|
||||||
|
|
||||||
|
- name: Test with Maven
|
||||||
|
run: mvn test
|
||||||
@@ -4,6 +4,14 @@
|
|||||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||||
compatibility issues with previous versions.
|
compatibility issues with previous versions.
|
||||||
|
|
||||||
|
## SuperSonic [0.8.6] - 2024-02-23
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- support view abstraction to Headless.
|
||||||
|
- add the Metric API to Headless and optimizing the Headless API.
|
||||||
|
- add integration tests to Headless.
|
||||||
|
- add TimeCorrector to Chat.
|
||||||
|
|
||||||
## SuperSonic [0.8.4] - 2024-01-19
|
## SuperSonic [0.8.4] - 2024-01-19
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
# SuperSonic (超音数)
|
# SuperSonic (超音数)
|
||||||
|
|
||||||
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) with semantic layer, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||||
|
|
||||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||||
|
|
||||||
@@ -14,6 +14,7 @@ The emergence of Large Language Model (LLM) like ChatGPT is reshaping the way in
|
|||||||
|
|
||||||
From our perspective, the key to filling the real-world gap lies in three aspects:
|
From our perspective, the key to filling the real-world gap lies in three aspects:
|
||||||
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
|
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
|
||||||
|
<img src="./docs/images/supersonic_ideas.png" height="65%" width="65%" align="center"/>
|
||||||
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to **mitigate hallucination**.
|
2. Augment the LLM with schema mappers(as a kind of preprocessor) and semantic correctors(as a kind of postprocessor) to **mitigate hallucination**.
|
||||||
3. Utilize rule-based schema parsers when necessary to **improve efficiency**(in terms of latency and cost).
|
3. Utilize rule-based schema parsers when necessary to **improve efficiency**(in terms of latency and cost).
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,8 @@
|
|||||||
|
|
||||||
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
||||||
1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。
|
1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。
|
||||||
|
|
||||||
|
<img src="./docs/images/supersonic_ideas.png" height="65%" width="65%" align="center"/>
|
||||||
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
|
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
|
||||||
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
|
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ public class User {
|
|||||||
return new User(id, name, displayName, email, isAdmin);
|
return new User(id, name, displayName, email, isAdmin);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static User get(Long id, String name) {
|
||||||
|
return new User(id, name, name, name, 0);
|
||||||
|
}
|
||||||
|
|
||||||
public static User getFakeUser() {
|
public static User getFakeUser() {
|
||||||
return new User(1L, "admin", "admin", "admin@email", 1);
|
return new User(1L, "admin", "admin", "admin@email", 1);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -7,25 +9,25 @@ import java.util.Set;
|
|||||||
|
|
||||||
public class SchemaMapInfo {
|
public class SchemaMapInfo {
|
||||||
|
|
||||||
private Map<Long, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
|
||||||
|
|
||||||
public Set<Long> getMatchedModels() {
|
public Set<Long> getMatchedViewInfos() {
|
||||||
return modelElementMatches.keySet();
|
return viewElementMatches.keySet();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElementMatch> getMatchedElements(Long model) {
|
public List<SchemaElementMatch> getMatchedElements(Long view) {
|
||||||
return modelElementMatches.get(model);
|
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Long, List<SchemaElementMatch>> getModelElementMatches() {
|
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
|
||||||
return modelElementMatches;
|
return viewElementMatches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setModelElementMatches(Map<Long, List<SchemaElementMatch>> modelElementMatches) {
|
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
|
||||||
this.modelElementMatches = modelElementMatches;
|
this.viewElementMatches = viewElementMatches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setMatchedElements(Long model, List<SchemaElementMatch> elementMatches) {
|
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
|
||||||
modelElementMatches.put(model, elementMatches);
|
viewElementMatches.put(view, elementMatches);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
|
||||||
|
|
||||||
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import lombok.Data;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class SchemaModelClusterMapInfo {
|
|
||||||
|
|
||||||
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
|
||||||
|
|
||||||
public Set<String> getMatchedModelClusters() {
|
|
||||||
return modelElementMatches.keySet();
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
|
||||||
for (String key : modelElementMatches.keySet()) {
|
|
||||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
|
||||||
return modelElementMatches.get(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Lists.newArrayList();
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
|
||||||
return modelElementMatches.get(modelCluster);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
|
||||||
return modelElementMatches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
|
||||||
if (CollectionUtils.isEmpty(modelIds)) {
|
|
||||||
return modelElementMatches;
|
|
||||||
}
|
|
||||||
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
|
||||||
for (String key : modelElementMatches.keySet()) {
|
|
||||||
for (Long modelId : modelIds) {
|
|
||||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
|
||||||
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return modelElementMatchesFiltered;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
|
||||||
this.modelElementMatches = modelElementMatches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
|
||||||
modelElementMatches.put(modelCluster, elementMatches);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -5,11 +5,11 @@ import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -26,7 +26,7 @@ public class SemanticParseInfo {
|
|||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode;
|
private String queryMode;
|
||||||
private ModelCluster model = new ModelCluster();
|
private SchemaElement view;
|
||||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||||
private SchemaElement entity;
|
private SchemaElement entity;
|
||||||
@@ -44,20 +44,6 @@ public class SemanticParseInfo {
|
|||||||
private SqlInfo sqlInfo = new SqlInfo();
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
private QueryType queryType = QueryType.ID;
|
private QueryType queryType = QueryType.ID;
|
||||||
|
|
||||||
public String getModelClusterKey() {
|
|
||||||
if (model == null) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return model.getKey();
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getModelName() {
|
|
||||||
if (model == null) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
return model.getName();
|
|
||||||
}
|
|
||||||
|
|
||||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -86,27 +72,15 @@ public class SemanticParseInfo {
|
|||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<Long, Integer> getModelElementCountMap() {
|
public Long getViewId() {
|
||||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
if (view == null) {
|
||||||
elementMatches.stream().filter(element -> element.getElement().getModel() != null)
|
return null;
|
||||||
.forEach(element -> {
|
}
|
||||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
return view.getView();
|
||||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
|
||||||
});
|
|
||||||
return elementCountMap;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getModelId() {
|
public SchemaElement getModel() {
|
||||||
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
return view;
|
||||||
Long modelId = -1L;
|
|
||||||
int maxCnt = 0;
|
|
||||||
for (Long model : elementCountMap.keySet()) {
|
|
||||||
if (elementCountMap.get(model) > maxCnt) {
|
|
||||||
maxCnt = elementCountMap.get(model);
|
|
||||||
modelId = model;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return modelId;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
public class SemanticSchema implements Serializable {
|
public class SemanticSchema implements Serializable {
|
||||||
|
|
||||||
private List<ModelSchema> modelSchemaList;
|
private List<ViewSchema> viewSchemaList;
|
||||||
|
|
||||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
public SemanticSchema(List<ViewSchema> viewSchemaList) {
|
||||||
this.modelSchemaList = modelSchemaList;
|
this.viewSchemaList = viewSchemaList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void add(ModelSchema schema) {
|
public void add(ViewSchema schema) {
|
||||||
modelSchemaList.add(schema);
|
viewSchemaList.add(schema);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
@@ -30,8 +30,8 @@ public class SemanticSchema implements Serializable {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = getElementsById(elementID, getEntities());
|
element = getElementsById(elementID, getEntities());
|
||||||
break;
|
break;
|
||||||
case MODEL:
|
case VIEW:
|
||||||
element = getElementsById(elementID, getModels());
|
element = getElementsById(elementID, getViews());
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = getElementsById(elementID, getMetrics());
|
element = getElementsById(elementID, getMetrics());
|
||||||
@@ -52,58 +52,29 @@ public class SemanticSchema implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
public Map<Long, String> getViewIdToName() {
|
||||||
Optional<SchemaElement> element = Optional.empty();
|
return viewSchemaList.stream()
|
||||||
|
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
|
||||||
switch (elementType) {
|
|
||||||
case ENTITY:
|
|
||||||
element = getElementsByNameOrAlias(name, getEntities());
|
|
||||||
break;
|
|
||||||
case MODEL:
|
|
||||||
element = getElementsByNameOrAlias(name, getModels());
|
|
||||||
break;
|
|
||||||
case METRIC:
|
|
||||||
element = getElementsByNameOrAlias(name, getMetrics());
|
|
||||||
break;
|
|
||||||
case DIMENSION:
|
|
||||||
element = getElementsByNameOrAlias(name, getDimensions());
|
|
||||||
break;
|
|
||||||
case VALUE:
|
|
||||||
element = getElementsByNameOrAlias(name, getDimensionValues());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if (element.isPresent()) {
|
|
||||||
return element.get();
|
|
||||||
} else {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<Long, String> getModelIdToName() {
|
|
||||||
return modelSchemaList.stream()
|
|
||||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensionValues() {
|
public List<SchemaElement> getDimensionValues() {
|
||||||
List<SchemaElement> dimensionValues = new ArrayList<>();
|
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||||
return dimensionValues;
|
return dimensionValues;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensions() {
|
public List<SchemaElement> getDimensions() {
|
||||||
List<SchemaElement> dimensions = new ArrayList<>();
|
List<SchemaElement> dimensions = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||||
return dimensions;
|
return dimensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
public List<SchemaElement> getDimensions(Long viewId) {
|
||||||
List<SchemaElement> dimensions = getDimensions();
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
return getElementsByModelId(modelIds, dimensions);
|
return getElementsByViewId(viewId, dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getDimensions(Long id) {
|
public SchemaElement getDimension(Long id) {
|
||||||
List<SchemaElement> dimensions = getDimensions();
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||||
return dimension.orElse(null);
|
return dimension.orElse(null);
|
||||||
@@ -111,43 +82,43 @@ public class SemanticSchema implements Serializable {
|
|||||||
|
|
||||||
public List<SchemaElement> getTags() {
|
public List<SchemaElement> getTags() {
|
||||||
List<SchemaElement> tags = new ArrayList<>();
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||||
return tags;
|
return tags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
public List<SchemaElement> getTags(Long viewId) {
|
||||||
List<SchemaElement> tags = new ArrayList<>();
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
modelSchemaList.stream().filter(schemaElement ->
|
viewSchemaList.stream().filter(schemaElement ->
|
||||||
modelIds.contains(schemaElement.getModel().getModel()))
|
viewId.equals(schemaElement.getView().getView()))
|
||||||
.forEach(d -> tags.addAll(d.getTags()));
|
.forEach(d -> tags.addAll(d.getTags()));
|
||||||
return tags;
|
return tags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics() {
|
public List<SchemaElement> getMetrics() {
|
||||||
List<SchemaElement> metrics = new ArrayList<>();
|
List<SchemaElement> metrics = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
public List<SchemaElement> getMetrics(Long viewId) {
|
||||||
List<SchemaElement> metrics = getMetrics();
|
List<SchemaElement> metrics = getMetrics();
|
||||||
return getElementsByModelId(modelIds, metrics);
|
return getElementsByViewId(viewId, metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getEntities() {
|
public List<SchemaElement> getEntities() {
|
||||||
List<SchemaElement> entities = new ArrayList<>();
|
List<SchemaElement> entities = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||||
return entities;
|
return entities;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getEntities(Set<Long> modelIds) {
|
public List<SchemaElement> getEntities(Long viewId) {
|
||||||
List<SchemaElement> entities = getEntities();
|
List<SchemaElement> entities = getEntities();
|
||||||
return getElementsByModelId(modelIds, entities);
|
return getElementsByViewId(viewId, entities);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
|
||||||
return elements.stream()
|
return elements.stream()
|
||||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,33 +128,30 @@ public class SemanticSchema implements Serializable {
|
|||||||
.findFirst();
|
.findFirst();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Optional<SchemaElement> getElementsByNameOrAlias(String name, List<SchemaElement> elements) {
|
public SchemaElement getView(Long viewId) {
|
||||||
return elements.stream()
|
List<SchemaElement> views = getViews();
|
||||||
.filter(schemaElement ->
|
return getElementsById(viewId, views).orElse(null);
|
||||||
name.equals(schemaElement.getName()) || (Objects.nonNull(schemaElement.getAlias())
|
|
||||||
&& schemaElement.getAlias().contains(name))
|
|
||||||
).findFirst();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getModels() {
|
public List<SchemaElement> getViews() {
|
||||||
List<SchemaElement> models = new ArrayList<>();
|
List<SchemaElement> views = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
|
||||||
return models;
|
return views;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
public Map<String, String> getBizNameToName(Long viewId) {
|
||||||
List<SchemaElement> allElements = new ArrayList<>();
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
allElements.addAll(getDimensions(modelIds));
|
allElements.addAll(getDimensions(viewId));
|
||||||
allElements.addAll(getMetrics(modelIds));
|
allElements.addAll(getMetrics(viewId));
|
||||||
return allElements.stream()
|
return allElements.stream()
|
||||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Long, ModelSchema> getModelSchemaMap() {
|
public Map<Long, ViewSchema> getViewSchemaMap() {
|
||||||
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
if (CollectionUtils.isEmpty(viewSchemaList)) {
|
||||||
return new HashMap<>();
|
return new HashMap<>();
|
||||||
}
|
}
|
||||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
|
||||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,26 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
import com.google.common.collect.Sets;
|
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelSchema {
|
public class ViewSchema {
|
||||||
|
|
||||||
private SchemaElement model;
|
private SchemaElement view;
|
||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||||
private Set<SchemaElement> tags = new HashSet<>();
|
private Set<SchemaElement> tags = new HashSet<>();
|
||||||
private SchemaElement entity = new SchemaElement();
|
private SchemaElement entity = new SchemaElement();
|
||||||
private List<ModelRela> modelRelas = new ArrayList<>();
|
private QueryConfig queryConfig;
|
||||||
|
|
||||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
Optional<SchemaElement> element = Optional.empty();
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
@@ -29,8 +29,8 @@ public class ModelSchema {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = Optional.ofNullable(entity);
|
element = Optional.ofNullable(entity);
|
||||||
break;
|
break;
|
||||||
case MODEL:
|
case VIEW:
|
||||||
element = Optional.of(model);
|
element = Optional.of(view);
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
@@ -61,8 +61,8 @@ public class ModelSchema {
|
|||||||
case ENTITY:
|
case ENTITY:
|
||||||
element = Optional.ofNullable(entity);
|
element = Optional.ofNullable(entity);
|
||||||
break;
|
break;
|
||||||
case MODEL:
|
case VIEW:
|
||||||
element = Optional.of(model);
|
element = Optional.of(view);
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
@@ -83,16 +83,31 @@ public class ModelSchema {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<Long> getModelClusterSet() {
|
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
|
||||||
if (CollectionUtils.isEmpty(modelRelas)) {
|
if (queryConfig == null) {
|
||||||
return Sets.newHashSet();
|
return null;
|
||||||
}
|
}
|
||||||
Set<Long> modelClusterSet = new HashSet<>();
|
if (queryConfig.getTagTypeDefaultConfig() == null) {
|
||||||
modelRelas.forEach(modelRela -> {
|
return null;
|
||||||
modelClusterSet.add(modelRela.getToModelId());
|
}
|
||||||
modelClusterSet.add(modelRela.getFromModelId());
|
return queryConfig.getTagTypeDefaultConfig().getTimeDefaultConfig();
|
||||||
});
|
}
|
||||||
return modelClusterSet;
|
|
||||||
|
public TimeDefaultConfig getMetricTypeTimeDefaultConfig() {
|
||||||
|
if (queryConfig == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (queryConfig.getMetricTypeDefaultConfig() == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return queryConfig.getMetricTypeDefaultConfig().getTimeDefaultConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
public TagTypeDefaultConfig getTagTypeDefaultConfig() {
|
||||||
|
if (queryConfig == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return queryConfig.getTagTypeDefaultConfig();
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -13,26 +12,5 @@ public class ChatDefaultConfigReq {
|
|||||||
private List<Long> dimensionIds = new ArrayList<>();
|
private List<Long> dimensionIds = new ArrayList<>();
|
||||||
private List<Long> metricIds = new ArrayList<>();
|
private List<Long> metricIds = new ArrayList<>();
|
||||||
|
|
||||||
/**
|
|
||||||
* default time span unit
|
|
||||||
*/
|
|
||||||
private Integer unit = 1;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* default time type: day
|
|
||||||
* DAY, WEEK, MONTH, YEAR
|
|
||||||
*/
|
|
||||||
private String period = Constants.DAY;
|
|
||||||
|
|
||||||
private TimeMode timeMode = TimeMode.LAST;
|
|
||||||
|
|
||||||
public enum TimeMode {
|
|
||||||
/**
|
|
||||||
* date mode
|
|
||||||
* LAST - a certain time
|
|
||||||
* RECENT - a period time
|
|
||||||
*/
|
|
||||||
LAST, RECENT
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
import javax.validation.constraints.NotNull;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static java.time.LocalDate.now;
|
|
||||||
|
|
||||||
@ToString
|
|
||||||
@Data
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class DictLatestTaskReq {
|
|
||||||
|
|
||||||
@NotNull
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
private List<Long> dimIds;
|
|
||||||
|
|
||||||
private String createdAt = now().plusDays(-4).toString();
|
|
||||||
}
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
@ToString
|
|
||||||
@Data
|
|
||||||
public class DictTaskFilterReq {
|
|
||||||
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
private String name;
|
|
||||||
|
|
||||||
private String createdBy;
|
|
||||||
|
|
||||||
private String createdAt;
|
|
||||||
|
|
||||||
private TaskStatusEnum status;
|
|
||||||
}
|
|
||||||
@@ -13,7 +13,7 @@ public class PluginQueryReq {
|
|||||||
|
|
||||||
private String type;
|
private String type;
|
||||||
|
|
||||||
private String model;
|
private String view;
|
||||||
|
|
||||||
private String pattern;
|
private String pattern;
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
|||||||
public class QueryReq {
|
public class QueryReq {
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private Long modelId;
|
private Long viewId;
|
||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ public class SimilarQueryReq {
|
|||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
|
||||||
private String modelId;
|
private Long viewId;
|
||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq.TimeMode;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -21,7 +21,7 @@ public class ChatDefaultRichConfigResp {
|
|||||||
private Integer unit = 1;
|
private Integer unit = 1;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* default time type: day
|
* default time type:
|
||||||
* DAY, WEEK, MONTH, YEAR
|
* DAY, WEEK, MONTH, YEAR
|
||||||
*/
|
*/
|
||||||
private String period = Constants.DAY;
|
private String period = Constants.DAY;
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
public class DataInfo {
|
public class DataInfo {
|
||||||
|
|
||||||
private Integer itemId;
|
private Integer itemId;
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class EntityInfo {
|
public class EntityInfo {
|
||||||
|
|
||||||
private ModelInfo modelInfo = new ModelInfo();
|
private ViewInfo viewInfo = new ViewInfo();
|
||||||
private List<DataInfo> dimensions = new ArrayList<>();
|
private List<DataInfo> dimensions = new ArrayList<>();
|
||||||
private List<DataInfo> metrics = new ArrayList<>();
|
private List<DataInfo> metrics = new ArrayList<>();
|
||||||
private String entityId;
|
private String entityId;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import java.io.Serializable;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelInfo extends DataInfo implements Serializable {
|
public class ViewInfo extends DataInfo implements Serializable {
|
||||||
|
|
||||||
private List<String> words;
|
private List<String> words;
|
||||||
private String primaryKey;
|
private String primaryKey;
|
||||||
@@ -21,70 +21,6 @@
|
|||||||
<groupId>org.springframework</groupId>
|
<groupId>org.springframework</groupId>
|
||||||
<artifactId>spring-context</artifactId>
|
<artifactId>spring-context</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.hankcs</groupId>
|
|
||||||
<artifactId>hanlp</artifactId>
|
|
||||||
<version>${hanlp.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.hadoop</groupId>
|
|
||||||
<artifactId>hadoop-client</artifactId>
|
|
||||||
<version>${hadoop.version}</version>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.apache.zookeeper</groupId>
|
|
||||||
<artifactId>zookeeper</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>javax.servlet</groupId>
|
|
||||||
<artifactId>servlet-api</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.hadoop</groupId>
|
|
||||||
<artifactId>hadoop-hdfs</artifactId>
|
|
||||||
<version>${hadoop.version}</version>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.apache.zookeeper</groupId>
|
|
||||||
<artifactId>zookeeper</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>javax.servlet</groupId>
|
|
||||||
<artifactId>servlet-api</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.hadoop</groupId>
|
|
||||||
<artifactId>hadoop-common</artifactId>
|
|
||||||
<version>${hadoop.version}</version>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-log4j12</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>log4j</groupId>
|
|
||||||
<artifactId>log4j</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.apache.zookeeper</groupId>
|
|
||||||
<artifactId>zookeeper</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.apache.curator</groupId>
|
|
||||||
<artifactId>*</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>javax.servlet</groupId>
|
|
||||||
<artifactId>servlet-api</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.testng</groupId>
|
<groupId>org.testng</groupId>
|
||||||
<artifactId>testng</artifactId>
|
<artifactId>testng</artifactId>
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ package com.tencent.supersonic.chat.core.agent;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -11,8 +14,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.Data;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class Agent extends RecordInfo {
|
public class Agent extends RecordInfo {
|
||||||
@@ -51,8 +52,8 @@ public class Agent extends RecordInfo {
|
|||||||
return enableSearch != null && enableSearch == 1;
|
return enableSearch != null && enableSearch == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean containsAllModel(Set<Long> detectModelIds) {
|
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||||
return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L);
|
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
|
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
|
||||||
@@ -64,12 +65,12 @@ public class Agent extends RecordInfo {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<Long> getModelIds(AgentToolType agentToolType) {
|
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
|
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
|
||||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||||
.flatMap(Collection::stream)
|
.flatMap(Collection::stream)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
|
|||||||
@@ -1,8 +1,25 @@
|
|||||||
package com.tencent.supersonic.chat.core.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
public enum AgentToolType {
|
public enum AgentToolType {
|
||||||
NL2SQL_RULE,
|
NL2SQL_RULE("基于规则Text-to-SQL"),
|
||||||
NL2SQL_LLM,
|
NL2SQL_LLM("基于大模型Text-to-SQL"),
|
||||||
PLUGIN,
|
PLUGIN("第三方插件");
|
||||||
ANALYTICS
|
|
||||||
|
private String title;
|
||||||
|
|
||||||
|
AgentToolType(String title) {
|
||||||
|
this.title = title;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<AgentToolType, String> getToolTypes() {
|
||||||
|
Map<AgentToolType, String> map = new HashMap<>();
|
||||||
|
map.put(NL2SQL_RULE, NL2SQL_RULE.title);
|
||||||
|
map.put(NL2SQL_LLM, NL2SQL_LLM.title);
|
||||||
|
map.put(PLUGIN, PLUGIN.title);
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.agent;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DataAnalyticsTool extends AgentTool {
|
|
||||||
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,16 +1,17 @@
|
|||||||
package com.tencent.supersonic.chat.core.agent;
|
package com.tencent.supersonic.chat.core.agent;
|
||||||
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class NL2SQLTool extends AgentTool {
|
public class NL2SQLTool extends AgentTool {
|
||||||
|
|
||||||
protected List<Long> modelIds;
|
protected List<Long> viewIds;
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
|
|||||||
private List<String> queryTypes;
|
private List<String> queryTypes;
|
||||||
|
|
||||||
public boolean isContainsAllModel() {
|
public boolean isContainsAllModel() {
|
||||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package com.tencent.supersonic.chat.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
|
||||||
import java.io.FileNotFoundException;
|
import java.io.FileNotFoundException;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -11,7 +13,7 @@ import org.springframework.context.annotation.Configuration;
|
|||||||
@Data
|
@Data
|
||||||
@Configuration
|
@Configuration
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LocalFileConfig {
|
public class ChatLocalFileConfig {
|
||||||
|
|
||||||
|
|
||||||
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.core.config;
|
package com.tencent.supersonic.chat.core.config;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||||
import com.tencent.supersonic.headless.api.response.MetricSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ public class OptimizationConfig {
|
|||||||
@Value("${embedding.mapper.round.number:10}")
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
private int embeddingMapperRoundNumber;
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
@Value("${embedding.mapper.distance.threshold:0.58}")
|
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||||
private Double embeddingMapperDistanceThreshold;
|
private Double embeddingMapperDistanceThreshold;
|
||||||
|
|
||||||
@Value("${s2SQL.linking.value.switch:true}")
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
@@ -73,9 +73,6 @@ public class OptimizationConfig {
|
|||||||
@Value("${text2sql.self.consistency.num:5}")
|
@Value("${text2sql.self.consistency.num:5}")
|
||||||
private int text2sqlSelfConsistencyNum;
|
private int text2sqlSelfConsistencyNum;
|
||||||
|
|
||||||
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
|
||||||
private String text2sqlCollectionName;
|
|
||||||
|
|
||||||
@Value("${parse.show.count:3}")
|
@Value("${parse.show.count:3}")
|
||||||
private Integer parseShowCount;
|
private Integer parseShowCount;
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -16,10 +23,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* basic semantic correction functionality, offering common methods and an
|
* basic semantic correction functionality, offering common methods and an
|
||||||
@@ -42,7 +45,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Set<Long> modelIds) {
|
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
@@ -52,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
// support fieldName and field alias
|
// support fieldName and field alias
|
||||||
Map<String, String> result = dbAllFields.stream()
|
Map<String, String> result = dbAllFields.stream()
|
||||||
.filter(entry -> modelIds.contains(entry.getModel()))
|
.filter(entry -> viewId.equals(entry.getView()))
|
||||||
.flatMap(schemaElement -> {
|
.flatMap(schemaElement -> {
|
||||||
Set<String> elements = new HashSet<>();
|
Set<String> elements = new HashSet<>();
|
||||||
elements.add(schemaElement.getName());
|
elements.add(schemaElement.getName());
|
||||||
@@ -74,14 +77,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
|
||||||
|
//decide whether add order by expression field to select
|
||||||
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
|
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
|
}
|
||||||
|
|
||||||
// If there is no aggregate function in the S2SQL statement and
|
// If there is no aggregate function in the S2SQL statement and
|
||||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||||
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
@@ -93,16 +102,15 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
needAddFields.removeAll(selectFields);
|
needAddFields.removeAll(selectFields);
|
||||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
Long viewId = semanticParseInfo.getView().getView();
|
||||||
|
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
|
||||||
List<SchemaElement> metrics = getMetricElements(queryContext, modelIds);
|
|
||||||
|
|
||||||
Map<String, String> metricToAggregate = metrics.stream()
|
Map<String, String> metricToAggregate = metrics.stream()
|
||||||
.map(schemaElement -> {
|
.map(schemaElement -> {
|
||||||
@@ -123,13 +131,28 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Set<Long> modelIds) {
|
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
return semanticSchema.getMetrics(modelIds);
|
return semanticSchema.getMetrics(viewId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||||
|
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||||
|
.flatMap(
|
||||||
|
schemaElement -> {
|
||||||
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
}
|
||||||
|
).collect(Collectors.toSet());
|
||||||
|
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
return dimensions;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
/**
|
|
||||||
* Perform SQL corrections on the "From" section in S2SQL.
|
|
||||||
*/
|
|
||||||
@Slf4j
|
|
||||||
public class FromCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
|
||||||
String modelName = semanticParseInfo.getModel().getName();
|
|
||||||
String correctSql = SqlParserReplaceHelper
|
|
||||||
.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Correcting SQL syntax, primarily including fixes to select, where, groupBy, and Having clauses
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class GrammarCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
private List<BaseSemanticCorrector> correctors;
|
||||||
|
|
||||||
|
public GrammarCorrector() {
|
||||||
|
correctors = new ArrayList<>();
|
||||||
|
correctors.add(new SelectCorrector());
|
||||||
|
correctors.add(new WhereCorrector());
|
||||||
|
correctors.add(new GroupByCorrector());
|
||||||
|
correctors.add(new HavingCorrector());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
for (BaseSemanticCorrector corrector : correctors) {
|
||||||
|
corrector.correct(queryContext, semanticParseInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,14 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import java.util.HashSet;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||||
|
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -22,47 +29,67 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo);
|
||||||
|
if (!needAddGroupBy) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
addGroupByFields(queryContext, semanticParseInfo);
|
addGroupByFields(queryContext, semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
|
ViewService viewService = ContextUtils.getBean(ViewService.class);
|
||||||
|
ModelService modelService = ContextUtils.getBean(ModelService.class);
|
||||||
|
ViewResp viewResp = viewService.getView(viewId);
|
||||||
|
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
MetaFilter metaFilter = new MetaFilter();
|
||||||
|
metaFilter.setIds(modelIds);
|
||||||
|
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
|
||||||
|
for (ModelResp modelResp : modelRespList) {
|
||||||
|
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
|
||||||
|
for (Dim dim : dimList) {
|
||||||
|
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//add dimension group by
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
// check has distinct
|
||||||
|
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
||||||
|
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//add alias field name
|
||||||
|
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||||
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// if only date in select not add group by.
|
||||||
|
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||||
|
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
|
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
//add alias field name
|
//add alias field name
|
||||||
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||||
.flatMap(
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
schemaElement -> {
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
Set<String> elements = new HashSet<>();
|
|
||||||
elements.add(schemaElement.getName());
|
|
||||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
|
||||||
elements.addAll(schemaElement.getAlias());
|
|
||||||
}
|
|
||||||
return elements.stream();
|
|
||||||
}
|
|
||||||
).collect(Collectors.toSet());
|
|
||||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
|
||||||
|
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// if only date in select not add group by.
|
|
||||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
|
||||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
|
||||||
Set<String> groupByFields = selectFields.stream()
|
Set<String> groupByFields = selectFields.stream()
|
||||||
.filter(field -> dimensions.contains(field))
|
.filter(field -> dimensions.contains(field))
|
||||||
.filter(field -> {
|
.filter(field -> {
|
||||||
@@ -72,13 +99,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
return true;
|
return true;
|
||||||
})
|
})
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||||
|
|
||||||
addAggregate(queryContext, semanticParseInfo);
|
addAggregate(queryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -1,17 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.core.env.Environment;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform SQL corrections on the "Having" section in S2SQL.
|
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||||
@@ -25,34 +29,38 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
addHaving(queryContext, semanticParseInfo);
|
addHaving(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
//add having expression filed to select
|
//decide whether add having expression field to select
|
||||||
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
|
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||||
|
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||||
addHavingToSelect(semanticParseInfo);
|
addHavingToSelect(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
Long viewId = semanticParseInfo.getView().getView();
|
||||||
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(metrics)) {
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -1,24 +1,33 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform schema corrections on the Schema information in S2QL.
|
* Perform schema corrections on the Schema information in S2SQL.
|
||||||
*/
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||||
@@ -26,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
correctAggFunction(semanticParseInfo);
|
correctAggFunction(semanticParseInfo);
|
||||||
|
|
||||||
replaceAlias(semanticParseInfo);
|
replaceAlias(semanticParseInfo);
|
||||||
@@ -40,20 +51,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getModel().getModelIds());
|
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -69,7 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
|
||||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +112,38 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
)));
|
)));
|
||||||
|
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
|
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||||
|
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
||||||
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
|
Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(linkingValues)) {
|
||||||
|
linkingValues = new ArrayList<>();
|
||||||
|
}
|
||||||
|
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
Set<String> removeFieldNames = whereExpressionList.stream()
|
||||||
|
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
|
||||||
|
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
|
||||||
|
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
|
||||||
|
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
|
||||||
|
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
|
||||||
|
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
|
||||||
|
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
||||||
|
|
||||||
|
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -16,8 +16,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||||
&& !CollectionUtils.isEmpty(selectFields)
|
&& !CollectionUtils.isEmpty(selectFields)
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
|
import java.util.Objects;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the time in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class TimeCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
parserDateDiffFunction(semanticParseInfo);
|
||||||
|
|
||||||
|
addLowerBoundDate(semanticParseInfo);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||||
|
if (Objects.isNull(dateBoundInfo)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
|
||||||
|
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
|
||||||
|
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
|
||||||
|
String upperDate = dateBoundInfo.getUpperDate();
|
||||||
|
try {
|
||||||
|
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
|
||||||
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("parseCondExpression", e);
|
||||||
|
}
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,32 +1,33 @@
|
|||||||
package com.tencent.supersonic.chat.core.corrector;
|
package com.tencent.supersonic.chat.core.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper;
|
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||||
*/
|
*/
|
||||||
@@ -38,8 +39,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
parserDateDiffFunction(semanticParseInfo);
|
|
||||||
|
|
||||||
addQueryFilter(queryContext, semanticParseInfo);
|
addQueryFilter(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||||
@@ -58,26 +57,29 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
log.error("parseCondExpression", e);
|
log.error("parseCondExpression", e);
|
||||||
}
|
}
|
||||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
|
||||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getModelId());
|
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
||||||
if (StringUtils.isNotBlank(currentDate)) {
|
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
|
||||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
||||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
&& StringUtils.isNotBlank(startEndDate.getRight())) {
|
||||||
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
String dateChName = TimeDimensionEnum.DAY.getChName();
|
||||||
|
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
|
||||||
|
startEndDate.getLeft(), dateChName, startEndDate.getRight());
|
||||||
|
try {
|
||||||
|
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||||
|
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||||
|
} catch (JSQLParserException e) {
|
||||||
|
log.error("parseCondExpression:{}", e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
@@ -99,15 +101,15 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
Long viewId = semanticParseInfo.getViewId();
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||||
aliasAndBizNameToTechName);
|
aliasAndBizNameToTechName);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DictConfig {
|
|
||||||
|
|
||||||
private Long modelId;
|
|
||||||
|
|
||||||
private List<DimValueInfo> dimValueInfoList;
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge;
|
|
||||||
|
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DimValue2DictCommand {
|
|
||||||
|
|
||||||
private DictUpdateMode updateMode;
|
|
||||||
|
|
||||||
private List<Long> modelIds;
|
|
||||||
|
|
||||||
private Map<Long, List<Long>> modelAndDimPair = new HashMap<>();
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge;
|
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
|
||||||
|
|
||||||
import java.util.Date;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class DimValueDictInfo {
|
|
||||||
|
|
||||||
private Long id;
|
|
||||||
|
|
||||||
private String name;
|
|
||||||
|
|
||||||
private String description;
|
|
||||||
|
|
||||||
private String command;
|
|
||||||
|
|
||||||
private TaskStatusEnum status;
|
|
||||||
|
|
||||||
private String createdBy;
|
|
||||||
|
|
||||||
private Date createdAt;
|
|
||||||
|
|
||||||
private Long elapsedMs;
|
|
||||||
|
|
||||||
private Set<Long> dimIds;
|
|
||||||
}
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge;
|
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
|
||||||
import java.util.List;
|
|
||||||
import javax.validation.constraints.NotNull;
|
|
||||||
|
|
||||||
public class DimValueInfo {
|
|
||||||
|
|
||||||
/**
|
|
||||||
* metricId、DimensionId、domainId
|
|
||||||
*/
|
|
||||||
private Long itemId;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* type: IntentionTypeEnum
|
|
||||||
* temporarily only supports dimension-related information
|
|
||||||
*/
|
|
||||||
@NotNull
|
|
||||||
private TypeEnums type = TypeEnums.DIMENSION;
|
|
||||||
|
|
||||||
private List<String> blackList;
|
|
||||||
private List<String> whiteList;
|
|
||||||
private List<String> ruleList;
|
|
||||||
private Boolean isDictInfo;
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.ToString;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@ToString
|
|
||||||
@Builder
|
|
||||||
public class ModelInfoStat implements Serializable {
|
|
||||||
|
|
||||||
private long modelCount;
|
|
||||||
|
|
||||||
private long metricModelCount;
|
|
||||||
|
|
||||||
private long dimensionModelCount;
|
|
||||||
|
|
||||||
private long dimensionValueModelCount;
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
|
||||||
|
|
||||||
import com.google.common.cache.Cache;
|
|
||||||
import com.google.common.cache.CacheBuilder;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.concurrent.TimeUnit;
|
|
||||||
import lombok.SneakyThrows;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
|
|
||||||
|
|
||||||
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
|
|
||||||
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
|
|
||||||
|
|
||||||
@SneakyThrows
|
|
||||||
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
|
|
||||||
if (cacheEnable) {
|
|
||||||
return modelSchemaCache.get(String.valueOf(ids), () -> {
|
|
||||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
|
||||||
modelSchemaCache.put(String.valueOf(ids), data);
|
|
||||||
return data;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
|
||||||
return data;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ModelSchema getModelSchema(Long model, Boolean cacheEnable) {
|
|
||||||
List<Long> ids = new ArrayList<>();
|
|
||||||
ids.add(model);
|
|
||||||
List<ModelSchemaResp> modelSchemaResps = fetchModelSchema(ids, cacheEnable);
|
|
||||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
|
||||||
Optional<ModelSchemaResp> modelSchemaResp = modelSchemaResps.stream()
|
|
||||||
.filter(d -> d.getId().equals(model)).findFirst();
|
|
||||||
if (modelSchemaResp.isPresent()) {
|
|
||||||
ModelSchemaResp modelSchema = modelSchemaResp.get();
|
|
||||||
return ModelSchemaBuilder.build(modelSchema);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ModelSchema> getModelSchema() {
|
|
||||||
return getModelSchema(new ArrayList<>());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<ModelSchema> getModelSchema(List<Long> ids) {
|
|
||||||
List<ModelSchema> domainSchemaList = new ArrayList<>();
|
|
||||||
|
|
||||||
for (ModelSchemaResp resp : fetchModelSchema(ids, true)) {
|
|
||||||
domainSchemaList.add(ModelSchemaBuilder.build(resp));
|
|
||||||
}
|
|
||||||
|
|
||||||
return domainSchemaList;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected abstract List<ModelSchemaResp> doFetchModelSchema(List<Long> ids);
|
|
||||||
}
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
|
||||||
|
|
||||||
import com.github.pagehelper.PageInfo;
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
|
||||||
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
|
|
||||||
import com.tencent.supersonic.headless.api.request.PageMetricReq;
|
|
||||||
import com.tencent.supersonic.headless.api.response.DomainResp;
|
|
||||||
import com.tencent.supersonic.headless.api.response.DimensionResp;
|
|
||||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
|
||||||
import com.tencent.supersonic.headless.api.response.MetricResp;
|
|
||||||
import com.tencent.supersonic.headless.api.response.ModelResp;
|
|
||||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
|
||||||
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
|
|
||||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
|
||||||
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
|
|
||||||
import com.tencent.supersonic.headless.api.request.QuerySqlReq;
|
|
||||||
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
|
|
||||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A semantic layer provides a simplified and consistent view of data from multiple sources.
|
|
||||||
* It abstracts away the complexity of the underlying data sources and provides a unified view
|
|
||||||
* of the data that is easier to understand and use.
|
|
||||||
* <p>
|
|
||||||
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
|
|
||||||
* Implementations of this interface should provide concrete implementations that interact with the
|
|
||||||
* underlying data sources and return results in a consistent format. Or it can be implemented
|
|
||||||
* as proxy to a remote semantic service.
|
|
||||||
* </p>
|
|
||||||
*/
|
|
||||||
public interface SemanticInterpreter {
|
|
||||||
|
|
||||||
SemanticQueryResp queryByStruct(QueryStructReq queryStructReq, User user);
|
|
||||||
|
|
||||||
SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
|
||||||
|
|
||||||
SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user);
|
|
||||||
|
|
||||||
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
|
||||||
|
|
||||||
List<ModelSchema> getModelSchema();
|
|
||||||
|
|
||||||
List<ModelSchema> getModelSchema(List<Long> ids);
|
|
||||||
|
|
||||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
|
||||||
|
|
||||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
|
||||||
|
|
||||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
|
||||||
|
|
||||||
List<DomainResp> getDomainList(User user);
|
|
||||||
|
|
||||||
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
|
||||||
|
|
||||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
|
||||||
|
|
||||||
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,21 +1,22 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.BeanUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseMapper implements SchemaMapper {
|
public abstract class BaseMapper implements SchemaMapper {
|
||||||
@@ -25,7 +26,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
|
|
||||||
String simpleName = this.getClass().getSimpleName();
|
String simpleName = this.getClass().getSimpleName();
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
doMap(queryContext);
|
doMap(queryContext);
|
||||||
@@ -34,13 +35,13 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
long cost = System.currentTimeMillis() - startTime;
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void doMap(QueryContext queryContext);
|
public abstract void doMap(QueryContext queryContext);
|
||||||
|
|
||||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
|
||||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = modelElementMatches.get(modelId);
|
schemaElementMatches = modelElementMatches.get(modelId);
|
||||||
@@ -66,14 +67,14 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID,
|
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
|
||||||
SemanticSchema semanticSchema) {
|
SemanticSchema semanticSchema) {
|
||||||
SchemaElement element = new SchemaElement();
|
SchemaElement element = new SchemaElement();
|
||||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||||
if (Objects.isNull(modelSchema)) {
|
if (Objects.isNull(viewSchema)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
|
||||||
if (Objects.isNull(elementDb)) {
|
if (Objects.isNull(elementDb)) {
|
||||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -27,22 +27,23 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
private MapperHelper mapperHelper;
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
|
Set<Long> detectViewIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
|
||||||
|
|
||||||
List<T> detects = detect(queryContext, terms, detectModelIds);
|
List<T> detects = detect(queryContext, terms, detectViewIds);
|
||||||
Map<MatchText, List<T>> result = new HashMap<>();
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
Set<T> results = new HashSet<>();
|
Set<T> results = new HashSet<>();
|
||||||
@@ -55,25 +56,26 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index);
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
detectSegments.add(detectSegment);
|
detectSegments.add(detectSegment);
|
||||||
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
}
|
}
|
||||||
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
detectByBatch(queryContext, results, detectViewIds, detectSegments);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||||
return terms.stream().sorted(Comparator.comparing(Term::length))
|
return terms.stream().sorted(Comparator.comparing(S2Term::length))
|
||||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
|
||||||
|
(value1, value2) -> value2));
|
||||||
}
|
}
|
||||||
|
|
||||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||||
@@ -101,10 +103,10 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
||||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||||
terms = filterByModelIds(terms, detectModelIds);
|
terms = filterByViewId(terms, viewIds);
|
||||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
|
||||||
List<T> matches = new ArrayList<>();
|
List<T> matches = new ArrayList<>();
|
||||||
if (Objects.isNull(matchResult)) {
|
if (Objects.isNull(matchResult)) {
|
||||||
return matches;
|
return matches;
|
||||||
@@ -119,27 +121,27 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
return matches;
|
return matches;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
|
||||||
logTerms(terms);
|
logTerms(terms);
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
if (CollectionUtils.isNotEmpty(viewIds)) {
|
||||||
terms = terms.stream().filter(term -> {
|
terms = terms.stream().filter(term -> {
|
||||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
Long viewId = NatureHelper.getViewId(term.getNature().toString());
|
||||||
if (Objects.nonNull(modelId)) {
|
if (Objects.nonNull(viewId)) {
|
||||||
return detectModelIds.contains(modelId);
|
return viewIds.contains(viewId);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
log.info("terms filter by viewId:{}", viewIds);
|
||||||
logTerms(terms);
|
logTerms(terms);
|
||||||
}
|
}
|
||||||
return terms;
|
return terms;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void logTerms(List<Term> terms) {
|
public void logTerms(List<S2Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(terms)) {
|
if (CollectionUtils.isEmpty(terms)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (Term term : terms) {
|
for (S2Term term : terms) {
|
||||||
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -148,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
public abstract String getMapKey(T a);
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
String detectSegment, int offset);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -14,11 +20,6 @@ import java.util.Map;
|
|||||||
import java.util.Map.Entry;
|
import java.util.Map.Entry;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||||
@@ -35,10 +36,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
private List<SchemaElement> allElements;
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
|
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectModelIds) {
|
Set<Long> detectViewIds) {
|
||||||
this.allElements = getSchemaElements(queryContext);
|
this.allElements = getSchemaElements(queryContext);
|
||||||
return super.match(queryContext, terms, detectModelIds);
|
return super.match(queryContext, terms, detectViewIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -53,16 +54,13 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex, Integer index, int offset) {
|
String detectSegment, int offset) {
|
||||||
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
|
|
||||||
if (StringUtils.isBlank(detectSegment)) {
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
|
||||||
|
|
||||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||||
|
|
||||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
@@ -72,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||||
schemaElements = schemaElements.stream()
|
schemaElements = schemaElements.stream()
|
||||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
for (SchemaElement schemaElement : schemaElements) {
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
@@ -98,7 +96,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||||
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
|
||||||
|
|
||||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
|
||||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import java.util.List;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A mapper that recognizes schema elements with vector embedding.
|
* A mapper that recognizes schema elements with vector embedding.
|
||||||
@@ -24,7 +25,8 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
//1. query from embedding by queryText
|
//1. query from embedding by queryText
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||||
|
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||||
|
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||||
@@ -34,16 +36,12 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
//2. build SchemaElementMatch by info
|
//2. build SchemaElementMatch by info
|
||||||
for (EmbeddingResult matchResult : matchResults) {
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
|
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
if (Objects.isNull(viewId)) {
|
||||||
SchemaElement.class);
|
|
||||||
|
|
||||||
String modelIdStr = matchResult.getMetadata().get("modelId");
|
|
||||||
if (StringUtils.isBlank(modelIdStr)) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
long modelId = Long.parseLong(modelIdStr);
|
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId,
|
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||||
queryContext.getSemanticSchema());
|
queryContext.getSemanticSchema());
|
||||||
if (schemaElement == null) {
|
if (schemaElement == null) {
|
||||||
continue;
|
continue;
|
||||||
@@ -56,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
.detectWord(matchResult.getDetectWord())
|
.detectWord(matchResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
//3. add to mapInfo
|
//3. add to mapInfo
|
||||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,17 +2,15 @@ package com.tencent.supersonic.chat.core.mapper;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -36,9 +34,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingConfig embeddingConfig;
|
private MetaEmbeddingService metaEmbeddingService;
|
||||||
|
|
||||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||||
@@ -52,7 +48,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||||
|
String detectSegment, int offset) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
|
|
||||||
List<String> queryTextsList = detectSegments.stream()
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
@@ -66,49 +68,29 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
optimizationConfig.getEmbeddingMapperBatch());
|
optimizationConfig.getEmbeddingMapperBatch());
|
||||||
|
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
List<String> queryTextsSub) {
|
List<String> queryTextsSub) {
|
||||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
Map<String, String> filterCondition = null;
|
|
||||||
// step1. build query params
|
// step1. build query params
|
||||||
// if only one modelId, add to filterCondition
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
|
||||||
filterCondition = new HashMap<>();
|
|
||||||
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
|
||||||
}
|
|
||||||
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
|
||||||
.queryTextsList(queryTextsSub)
|
|
||||||
.filterCondition(filterCondition)
|
|
||||||
.queryEmbeddings(null)
|
|
||||||
.build();
|
|
||||||
// step2. retrieveQuery by detectSegment
|
// step2. retrieveQuery by detectSegment
|
||||||
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||||
embeddingConfig.getMetaCollectionName(), retrieveQuery, embeddingNumber);
|
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// step3. build EmbeddingResults. filter by modelId
|
// step3. build EmbeddingResults
|
||||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||||
.map(retrieveQueryResult -> {
|
.map(retrieveQueryResult -> {
|
||||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
|
||||||
retrievals.removeIf(retrieval -> {
|
|
||||||
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
|
||||||
if (StringUtils.isBlank(modelIdStr)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return retrieveQueryResult;
|
return retrieveQueryResult;
|
||||||
})
|
})
|
||||||
@@ -119,6 +101,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||||
embeddingResult.setName(retrieval.getQuery());
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
|
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||||
|
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
|
||||||
|
embeddingResult.setMetadata(convertedMap);
|
||||||
return embeddingResult;
|
return embeddingResult;
|
||||||
}))
|
}))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -132,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
selectResultInOneRound(results, oneRoundResults);
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
|
||||||
Integer startIndex, Integer index, int offset) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,20 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import java.util.List;
|
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||||
import java.util.stream.Collectors;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||||
*/
|
*/
|
||||||
@@ -23,12 +24,12 @@ public class EntityMapper extends BaseMapper {
|
|||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
|
||||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
|
||||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElement entity = getEntity(modelId, queryContext);
|
SchemaElement entity = getEntity(viewId, queryContext);
|
||||||
if (entity == null || entity.getId() == null) {
|
if (entity == null || entity.getId() == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -64,9 +65,9 @@ public class EntityMapper extends BaseMapper {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private SchemaElement getEntity(Long modelId, QueryContext queryContext) {
|
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||||
return modelSchema.getEntity();
|
return modelSchema.getEntity();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
|
||||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.LinkedHashSet;
|
import java.util.LinkedHashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -34,17 +34,20 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private KnowledgeService knowledgeService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectModelIds) {
|
Set<Long> detectViewIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
|
||||||
|
|
||||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
|
||||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
@@ -57,20 +60,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex, Integer index, int offset) {
|
String detectSegment, int offset) {
|
||||||
String text = queryContext.getQueryText();
|
|
||||||
Integer agentId = queryContext.getAgentId();
|
|
||||||
String detectSegment = text.substring(startIndex, index);
|
|
||||||
|
|
||||||
// step1. pre search
|
// step1. pre search
|
||||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||||
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
// step2. suffix search
|
// step2. suffix search
|
||||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||||
oneDetectionMaxSize, agentId, detectModelIds).stream()
|
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
|
|
||||||
|
|||||||
@@ -1,24 +1,26 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* A mapper that recognizes schema elements with keyword.
|
* A mapper that recognizes schema elements with keyword.
|
||||||
@@ -31,7 +33,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
//1.hanlpDict Match
|
//1.hanlpDict Match
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||||
|
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||||
|
|
||||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||||
@@ -45,7 +48,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||||
List<Term> terms) {
|
List<S2Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(mapResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -56,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
|
|
||||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||||
for (String nature : hanlpMapResult.getNatures()) {
|
for (String nature : hanlpMapResult.getNatures()) {
|
||||||
Long modelId = NatureHelper.getModelId(nature);
|
Long viewId = NatureHelper.getViewId(nature);
|
||||||
if (Objects.isNull(modelId)) {
|
if (Objects.isNull(viewId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||||
@@ -65,8 +68,8 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID,
|
SchemaElement element = getSchemaElement(viewId, elementType,
|
||||||
queryContext.getSemanticSchema());
|
elementID, queryContext.getSemanticSchema());
|
||||||
if (element == null) {
|
if (element == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -82,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.detectWord(hanlpMapResult.getDetectWord())
|
.detectWord(hanlpMapResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -103,12 +106,12 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||||
.build();
|
.build();
|
||||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
|
||||||
if (CollectionUtils.isEmpty(elements)) {
|
if (CollectionUtils.isEmpty(elements)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -12,10 +17,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.Data;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Service
|
@Service
|
||||||
@@ -35,8 +36,8 @@ public class MapperHelper {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer getStepOffset(List<Term> termList, Integer index) {
|
public Integer getStepOffset(List<S2Term> termList, Integer index) {
|
||||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
|
||||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||||
|
|
||||||
for (int j = 0; j < termList.size() - 1; j++) {
|
for (int j = 0; j < termList.size() - 1; j++) {
|
||||||
@@ -61,7 +62,7 @@ public class MapperHelper {
|
|||||||
*/
|
*/
|
||||||
public boolean existDimensionValues(List<String> natures) {
|
public boolean existDimensionValues(List<String> natures) {
|
||||||
for (String nature : natures) {
|
for (String nature : natures) {
|
||||||
if (NatureHelper.isDimensionValueModelId(nature)) {
|
if (NatureHelper.isDimensionValueViewId(nature)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -81,33 +82,33 @@ public class MapperHelper {
|
|||||||
detectSegment.length());
|
detectSegment.length());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Set<Long> getModelIds(Long modelId, Agent agent) {
|
public Set<Long> getViewIds(Long viewId, Agent agent) {
|
||||||
|
|
||||||
Set<Long> detectModelIds = new HashSet<>();
|
Set<Long> detectViewIds = new HashSet<>();
|
||||||
if (Objects.nonNull(agent)) {
|
if (Objects.nonNull(agent)) {
|
||||||
detectModelIds = agent.getModelIds(null);
|
detectViewIds = agent.getViewIds(null);
|
||||||
}
|
}
|
||||||
//contains all
|
//contains all
|
||||||
if (Agent.containsAllModel(detectModelIds)) {
|
if (Agent.containsAllModel(detectViewIds)) {
|
||||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||||
Set<Long> result = new HashSet<>();
|
Set<Long> result = new HashSet<>();
|
||||||
result.add(modelId);
|
result.add(viewId);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Objects.nonNull(detectModelIds)) {
|
if (Objects.nonNull(detectViewIds)) {
|
||||||
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
|
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
|
||||||
if (detectModelIds.contains(modelId)) {
|
if (detectViewIds.contains(viewId)) {
|
||||||
Set<Long> result = new HashSet<>();
|
Set<Long> result = new HashSet<>();
|
||||||
result.add(modelId);
|
result.add(viewId);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return detectModelIds;
|
return detectViewIds;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -12,6 +13,6 @@ import java.util.Set;
|
|||||||
*/
|
*/
|
||||||
public interface MatchStrategy<T> {
|
public interface MatchStrategy<T> {
|
||||||
|
|
||||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
|
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/***
|
|
||||||
* ModelClusterMapper build a cluster from
|
|
||||||
* connectable data models based on model-rela configuration
|
|
||||||
* and generate SchemaModelClusterMapInfo
|
|
||||||
*/
|
|
||||||
public class ModelClusterMapper implements SchemaMapper {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void map(QueryContext queryContext) {
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
|
||||||
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
|
|
||||||
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
|
|
||||||
for (ModelCluster modelCluster : modelClusters) {
|
|
||||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
|
||||||
if (modelCluster.getModelIds().contains(modelId)) {
|
|
||||||
modelClusterElementMatches.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
|
|
||||||
.addAll(schemaMapInfo.getMatchedElements(modelId));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
|
||||||
modelClusterMapInfo.setModelElementMatches(modelClusterElementMatches);
|
|
||||||
queryContext.setModelClusterMapInfo(modelClusterMapInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
|
|
||||||
SemanticSchema semanticSchema) {
|
|
||||||
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
|
|
||||||
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
|
||||||
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
|
|
||||||
modelCluster.removeIf(model -> !matchedModels.contains(model));
|
|
||||||
}).filter(modelCluster -> modelCluster.size() > 0).map(ModelCluster::build).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|||||||
@@ -1,20 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class QueryFilterMapper implements SchemaMapper {
|
public class QueryFilterMapper implements SchemaMapper {
|
||||||
@@ -23,22 +24,22 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void map(QueryContext queryContext) {
|
||||||
Long modelId = queryContext.getModelId();
|
Long viewId = queryContext.getViewId();
|
||||||
if (modelId == null || modelId <= 0) {
|
if (viewId == null || viewId <= 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
clearOtherSchemaElementMatch(modelId, schemaMapInfo);
|
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
|
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||||
}
|
}
|
||||||
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getModelElementMatches().entrySet()) {
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||||
if (!entry.getKey().equals(modelId)) {
|
if (!entry.getKey().equals(modelId)) {
|
||||||
entry.getValue().clear();
|
entry.getValue().clear();
|
||||||
}
|
}
|
||||||
@@ -60,7 +61,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
|||||||
.name(String.valueOf(filter.getValue()))
|
.name(String.valueOf(filter.getValue()))
|
||||||
.type(SchemaElementType.VALUE)
|
.type(SchemaElementType.VALUE)
|
||||||
.bizName(filter.getBizName())
|
.bizName(filter.getBizName())
|
||||||
.model(queryContext.getModelId())
|
.view(queryContext.getViewId())
|
||||||
.build();
|
.build();
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
import com.tencent.supersonic.headless.core.knowledge.SearchService;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
|
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -14,6 +15,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -25,9 +27,12 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
|
|
||||||
private static final int SEARCH_SIZE = 3;
|
private static final int SEARCH_SIZE = 3;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private KnowledgeService knowledgeService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||||
Set<Long> detectModelIds) {
|
Set<Long> detectViewIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||||
|
|
||||||
@@ -51,10 +56,10 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
String detectSegment = text.substring(detectIndex);
|
String detectSegment = text.substring(detectIndex);
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
|
SearchService.SEARCH_SIZE, detectViewIds);
|
||||||
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||||
detectSegment, SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
|
detectSegment, SEARCH_SIZE, detectViewIds);
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||||
@@ -88,9 +93,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex,
|
String detectSegment, int offset) {
|
||||||
Integer i, int offset) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,24 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* LLMProxy based on langchain4j Java version.
|
* LLMProxy based on langchain4j Java version.
|
||||||
*/
|
*/
|
||||||
@@ -37,12 +38,12 @@ public class JavaLLMProxy implements LLMProxy {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||||
|
|
||||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getViewName();
|
||||||
LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
|
LLMResp result = sqlGeneration.generation(llmReq, viewId);
|
||||||
result.setQuery(llmReq.getQueryText());
|
result.setQuery(llmReq.getQueryText());
|
||||||
result.setModelName(modelName);
|
result.setModelName(modelName);
|
||||||
return result;
|
return result;
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ public interface LLMProxy {
|
|||||||
|
|
||||||
boolean isSkip(QueryContext queryContext);
|
boolean isSkip(QueryContext queryContext);
|
||||||
|
|
||||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
LLMResp query2sql(LLMReq llmReq, Long viewId);
|
||||||
|
|
||||||
FunctionResp requestFunction(FunctionReq functionReq);
|
FunctionResp requestFunction(FunctionReq functionReq);
|
||||||
|
|
||||||
|
|||||||
@@ -1,19 +1,16 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
import com.alibaba.fastjson.JSON;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
|
||||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
||||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import java.net.URI;
|
|
||||||
import java.net.URL;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.MapUtils;
|
import org.apache.commons.collections4.MapUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -28,6 +25,10 @@ import org.springframework.stereotype.Component;
|
|||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* PythonLLMProxy sends requests to LangChain-based python service.
|
* PythonLLMProxy sends requests to LangChain-based python service.
|
||||||
*/
|
*/
|
||||||
@@ -47,10 +48,10 @@ public class PythonLLMProxy implements LLMProxy {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
|
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
try {
|
try {
|
||||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser;
|
package com.tencent.supersonic.chat.core.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
@@ -12,14 +12,15 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
|||||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||||
@@ -49,21 +50,21 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
//1. entity queryType
|
//1. entity queryType
|
||||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
Long viewId = parseInfo.getViewId();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||||
//If all the fields in the SELECT statement are of tag type.
|
//If all the fields in the SELECT statement are of tag type.
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||||
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
|
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||||
return QueryType.ID;
|
return QueryType.ID;
|
||||||
}
|
}
|
||||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||||
return QueryType.TAG;
|
return QueryType.TAG;
|
||||||
@@ -71,8 +72,8 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
//2. metric queryType
|
//2. metric queryType
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.plugin;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
@@ -18,7 +18,6 @@ import com.tencent.supersonic.chat.core.query.QueryManager;
|
|||||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -56,13 +55,13 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
|
|
||||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||||
Plugin plugin = pluginRecallResult.getPlugin();
|
Plugin plugin = pluginRecallResult.getPlugin();
|
||||||
Set<Long> modelIds = pluginRecallResult.getModelIds();
|
Set<Long> viewIds = pluginRecallResult.getViewIds();
|
||||||
if (plugin.isContainsAllModel()) {
|
if (plugin.isContainsAllModel()) {
|
||||||
modelIds = Sets.newHashSet(-1L);
|
viewIds = Sets.newHashSet(-1L);
|
||||||
}
|
}
|
||||||
for (Long modelId : modelIds) {
|
for (Long viewId : viewIds) {
|
||||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
|
||||||
queryContext, pluginRecallResult.getDistance());
|
queryContext, pluginRecallResult.getDistance());
|
||||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||||
@@ -75,20 +74,19 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin,
|
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
|
||||||
QueryContext queryContext, double distance) {
|
QueryContext queryContext, double distance) {
|
||||||
List<SchemaElementMatch> schemaElementMatches =
|
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||||
queryContext.getModelClusterMapInfo().getMatchedElements(modelId);
|
|
||||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
|
||||||
modelId = plugin.getModelList().get(0);
|
viewId = plugin.getViewList().get(0);
|
||||||
}
|
}
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||||
pluginParseResult.setPlugin(plugin);
|
pluginParseResult.setPlugin(plugin);
|
||||||
|
|||||||
@@ -57,15 +57,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
|||||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||||
log.info("embedding plugin resolve: {}", pair);
|
log.info("embedding plugin resolve: {}", pair);
|
||||||
if (pair.getLeft()) {
|
if (pair.getLeft()) {
|
||||||
Set<Long> modelList = pair.getRight();
|
Set<Long> viewList = pair.getRight();
|
||||||
if (CollectionUtils.isEmpty(modelList)) {
|
if (CollectionUtils.isEmpty(viewList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double distance = embeddingRetrieval.getDistance();
|
double distance = embeddingRetrieval.getDistance();
|
||||||
double score = queryContext.getQueryText().length() * (1 - distance);
|
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||||
return PluginRecallResult.builder()
|
return PluginRecallResult.builder()
|
||||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -12,15 +12,16 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
|||||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
||||||
*/
|
*/
|
||||||
@@ -56,19 +57,19 @@ public class FunctionCallParser extends PluginParser {
|
|||||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||||
if (pluginResolveResult.getLeft()) {
|
if (pluginResolveResult.getLeft()) {
|
||||||
Set<Long> modelList = pluginResolveResult.getRight();
|
Set<Long> viewList = pluginResolveResult.getRight();
|
||||||
if (CollectionUtils.isEmpty(modelList)) {
|
if (CollectionUtils.isEmpty(viewList)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
double score = queryContext.getQueryText().length();
|
double score = queryContext.getQueryText().length();
|
||||||
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
|
return PluginRecallResult.builder().plugin(plugin).viewIds(viewList).score(score).build();
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FunctionResp functionCall(QueryContext queryContext) {
|
public FunctionResp functionCall(QueryContext queryContext) {
|
||||||
List<PluginParseConfig> pluginToFunctionCall =
|
List<PluginParseConfig> pluginToFunctionCall =
|
||||||
getPluginToFunctionCall(queryContext.getModelId(), queryContext);
|
getPluginToFunctionCall(queryContext.getViewId(), queryContext);
|
||||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||||
log.info("function call parser, plugin is empty, skip");
|
log.info("function call parser, plugin is empty, skip");
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -1,146 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Map.Entry;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class HeuristicModelResolver implements ModelResolver {
|
|
||||||
|
|
||||||
protected static String selectModelBySchemaElementMatchScore(Map<String, SemanticQuery> modelQueryModes,
|
|
||||||
SchemaModelClusterMapInfo schemaMap) {
|
|
||||||
//model count priority
|
|
||||||
String modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
|
||||||
if (Objects.nonNull(modelIdByModelCount)) {
|
|
||||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
|
||||||
return modelIdByModelCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
|
||||||
if (modelTypeMap.size() == 1) {
|
|
||||||
String modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
|
||||||
if (modelQueryModes.containsKey(modelSelect)) {
|
|
||||||
log.info("selectModel with only one Model [{}]", modelSelect);
|
|
||||||
return modelSelect;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Map.Entry<String, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
|
||||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
|
||||||
.sorted((o1, o2) -> {
|
|
||||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
|
||||||
if (difference == 0) {
|
|
||||||
return (int) ((o2.getValue().getMaxSimilarity()
|
|
||||||
- o1.getValue().getMaxSimilarity()) * 100);
|
|
||||||
}
|
|
||||||
return difference;
|
|
||||||
}).findFirst().orElse(null);
|
|
||||||
if (maxModel != null) {
|
|
||||||
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
|
||||||
return maxModel.getKey();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String getModelIdByMatchModelScore(SchemaModelClusterMapInfo schemaMap) {
|
|
||||||
Map<String, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
|
||||||
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
|
||||||
Map<String, Double> modelIdToModelScore = new HashMap<>();
|
|
||||||
if (Objects.nonNull(modelElementMatches)) {
|
|
||||||
for (Entry<String, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
|
||||||
String modelId = modelElementMatch.getKey();
|
|
||||||
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
|
|
||||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
|
||||||
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
|
||||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
|
||||||
// get sum of model match score
|
|
||||||
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
|
||||||
modelIdToModelScore.put(modelId, score);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Entry<String, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
|
||||||
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
|
||||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
|
||||||
if (Objects.nonNull(maxModelScore)) {
|
|
||||||
return maxModelScore.getKey();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
|
|
||||||
Map<String, ModelMatchResult> modelCount = new HashMap<>();
|
|
||||||
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
|
||||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
|
||||||
if (!modelCount.containsKey(entry.getKey())) {
|
|
||||||
modelCount.put(entry.getKey(), new ModelMatchResult());
|
|
||||||
}
|
|
||||||
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
|
||||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
|
||||||
schemaElementMatches.stream()
|
|
||||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
|
||||||
schemaElementMatch.getElement().getType()));
|
|
||||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
|
||||||
.sorted((o1, o2) ->
|
|
||||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
|
||||||
).findFirst().orElse(null);
|
|
||||||
if (schemaElementMatchMax != null) {
|
|
||||||
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
|
||||||
}
|
|
||||||
modelMatchResult.setCount(schemaElementTypes.size());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return modelCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
|
||||||
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
|
||||||
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
|
||||||
Long modelId = queryContext.getModelId();
|
|
||||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
|
||||||
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
|
|
||||||
return getModelClusterByModelId(modelId, matchedModelClusters);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, SemanticQuery> modelQueryModes = new HashMap<>();
|
|
||||||
for (String matchedModel : matchedModelClusters) {
|
|
||||||
modelQueryModes.put(matchedModel, null);
|
|
||||||
}
|
|
||||||
if (modelQueryModes.size() == 1) {
|
|
||||||
return modelQueryModes.keySet().stream().findFirst().get();
|
|
||||||
}
|
|
||||||
return selectModelBySchemaElementMatchScore(modelQueryModes, mapInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getModelClusterByModelId(Long modelId, Set<String> modelClusterKeySet) {
|
|
||||||
for (String modelClusterKey : modelClusterKeySet) {
|
|
||||||
if (ModelCluster.build(modelClusterKey).getModelIds().contains(modelId)) {
|
|
||||||
return modelClusterKey;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class HeuristicViewResolver implements ViewResolver {
|
||||||
|
|
||||||
|
protected static Long selectViewBySchemaElementMatchScore(Map<Long, SemanticQuery> viewQueryModes,
|
||||||
|
SchemaMapInfo schemaMap) {
|
||||||
|
//view count priority
|
||||||
|
Long viewIdByViewCount = getViewIdByMatchViewScore(schemaMap);
|
||||||
|
if (Objects.nonNull(viewIdByViewCount)) {
|
||||||
|
log.info("selectView by view count:{}", viewIdByViewCount);
|
||||||
|
return viewIdByViewCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<Long, ViewMatchResult> viewTypeMap = getViewTypeMap(schemaMap);
|
||||||
|
if (viewTypeMap.size() == 1) {
|
||||||
|
Long viewSelect = new ArrayList<>(viewTypeMap.entrySet()).get(0).getKey();
|
||||||
|
if (viewQueryModes.containsKey(viewSelect)) {
|
||||||
|
log.info("selectView with only one View [{}]", viewSelect);
|
||||||
|
return viewSelect;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Map.Entry<Long, ViewMatchResult> maxView = viewTypeMap.entrySet().stream()
|
||||||
|
.filter(entry -> viewQueryModes.containsKey(entry.getKey()))
|
||||||
|
.sorted((o1, o2) -> {
|
||||||
|
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||||
|
if (difference == 0) {
|
||||||
|
return (int) ((o2.getValue().getMaxSimilarity()
|
||||||
|
- o1.getValue().getMaxSimilarity()) * 100);
|
||||||
|
}
|
||||||
|
return difference;
|
||||||
|
}).findFirst().orElse(null);
|
||||||
|
if (maxView != null) {
|
||||||
|
log.info("selectView with multiple Views [{}]", maxView.getKey());
|
||||||
|
return maxView.getKey();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Long getViewIdByMatchViewScore(SchemaMapInfo schemaMap) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> viewElementMatches = schemaMap.getViewElementMatches();
|
||||||
|
// calculate view match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||||
|
Map<Long, Double> viewIdToViewScore = new HashMap<>();
|
||||||
|
if (Objects.nonNull(viewElementMatches)) {
|
||||||
|
for (Entry<Long, List<SchemaElementMatch>> viewElementMatch : viewElementMatches.entrySet()) {
|
||||||
|
Long viewId = viewElementMatch.getKey();
|
||||||
|
List<Double> viewMatchesScore = viewElementMatch.getValue().stream()
|
||||||
|
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||||
|
.filter(elementMatch -> SchemaElementType.VIEW.equals(elementMatch.getElement().getType()))
|
||||||
|
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEmpty(viewMatchesScore)) {
|
||||||
|
// get sum of view match score
|
||||||
|
double score = viewMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||||
|
viewIdToViewScore.put(viewId, score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Entry<Long, Double> maxViewScore = viewIdToViewScore.entrySet().stream()
|
||||||
|
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||||
|
log.info("maxViewCount:{},viewIdToViewCount:{}", maxViewScore, viewIdToViewScore);
|
||||||
|
if (Objects.nonNull(maxViewScore)) {
|
||||||
|
return maxViewScore.getKey();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<Long, ViewMatchResult> getViewTypeMap(SchemaMapInfo schemaMap) {
|
||||||
|
Map<Long, ViewMatchResult> viewCount = new HashMap<>();
|
||||||
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getViewElementMatches().entrySet()) {
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||||
|
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||||
|
if (!viewCount.containsKey(entry.getKey())) {
|
||||||
|
viewCount.put(entry.getKey(), new ViewMatchResult());
|
||||||
|
}
|
||||||
|
ViewMatchResult viewMatchResult = viewCount.get(entry.getKey());
|
||||||
|
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||||
|
schemaElementMatches.stream()
|
||||||
|
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||||
|
schemaElementMatch.getElement().getType()));
|
||||||
|
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||||
|
.sorted((o1, o2) ->
|
||||||
|
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||||
|
).findFirst().orElse(null);
|
||||||
|
if (schemaElementMatchMax != null) {
|
||||||
|
viewMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||||
|
}
|
||||||
|
viewMatchResult.setCount(schemaElementTypes.size());
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return viewCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long resolve(QueryContext queryContext, Set<Long> agentViewIds) {
|
||||||
|
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||||
|
Set<Long> matchedViews = mapInfo.getMatchedViewInfos();
|
||||||
|
Long viewId = queryContext.getViewId();
|
||||||
|
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||||
|
if (CollectionUtils.isEmpty(agentViewIds) || agentViewIds.contains(viewId)) {
|
||||||
|
return viewId;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isNotEmpty(agentViewIds)) {
|
||||||
|
matchedViews.retainAll(agentViewIds);
|
||||||
|
}
|
||||||
|
Map<Long, SemanticQuery> viewQueryModes = new HashMap<>();
|
||||||
|
for (Long viewIds : matchedViews) {
|
||||||
|
viewQueryModes.put(viewIds, null);
|
||||||
|
}
|
||||||
|
if (viewQueryModes.size() == 1) {
|
||||||
|
return viewQueryModes.keySet().stream().findFirst().get();
|
||||||
|
}
|
||||||
|
return selectViewBySchemaElementMatchScore(viewQueryModes, mapInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,31 +1,35 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
||||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||||
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
|
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
|
||||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
|
||||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
|
||||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
@@ -35,12 +39,6 @@ import java.util.Objects;
|
|||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
@@ -63,79 +61,54 @@ public class LLMRequestService {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) {
|
public Long getViewId(QueryContext queryCtx) {
|
||||||
Agent agent = queryCtx.getAgent();
|
Agent agent = queryCtx.getAgent();
|
||||||
Set<Long> distinctModelIds = new HashSet<>();
|
Set<Long> agentViewIds = new HashSet<>();
|
||||||
if (Objects.nonNull(agent)) {
|
if (Objects.nonNull(agent)) {
|
||||||
distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);
|
agentViewIds = agent.getViewIds(AgentToolType.NL2SQL_LLM);
|
||||||
}
|
}
|
||||||
if (llmParserConfig.getAllModel()) {
|
if (Agent.containsAllModel(agentViewIds)) {
|
||||||
ModelCluster modelCluster = ModelCluster.build(distinctModelIds);
|
agentViewIds = new HashSet<>();
|
||||||
if (!CollectionUtils.isEmpty(queryCtx.getCandidateQueries())) {
|
|
||||||
queryCtx.getCandidateQueries().stream().forEach(o -> {
|
|
||||||
if (LLMSqlQuery.QUERY_MODE.equals(o.getParseInfo().getQueryMode())) {
|
|
||||||
o.getParseInfo().setModel(modelCluster);
|
|
||||||
}
|
}
|
||||||
});
|
ViewResolver viewResolver = ComponentFactory.getModelResolver();
|
||||||
}
|
return viewResolver.resolve(queryCtx, agentViewIds);
|
||||||
SemanticQuery semanticQuery = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE);
|
|
||||||
semanticQuery.getParseInfo().setModel(modelCluster);
|
|
||||||
List<SchemaElementMatch> schemaElementMatches = new ArrayList<>();
|
|
||||||
distinctModelIds.stream().forEach(o -> {
|
|
||||||
if (!CollectionUtils.isEmpty(queryCtx.getMapInfo().getMatchedElements(o))) {
|
|
||||||
schemaElementMatches.addAll(queryCtx.getMapInfo().getMatchedElements(o));
|
|
||||||
}
|
|
||||||
});
|
|
||||||
queryCtx.getModelClusterMapInfo().setMatchedElements(modelCluster.getKey(), schemaElementMatches);
|
|
||||||
return modelCluster;
|
|
||||||
}
|
|
||||||
if (Agent.containsAllModel(distinctModelIds)) {
|
|
||||||
distinctModelIds = new HashSet<>();
|
|
||||||
}
|
|
||||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
|
||||||
String modelCluster = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
|
||||||
log.info("resolve modelId:{},llmParser Models:{}", modelCluster, distinctModelIds);
|
|
||||||
return ModelCluster.build(modelCluster);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Set<Long> modelIdSet) {
|
public NL2SQLTool getParserTool(QueryContext queryCtx, Long viewId) {
|
||||||
Agent agent = queryCtx.getAgent();
|
Agent agent = queryCtx.getAgent();
|
||||||
|
if (Objects.isNull(agent)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
|
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
|
||||||
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
||||||
.filter(tool -> {
|
.filter(tool -> {
|
||||||
List<Long> modelIds = tool.getModelIds();
|
List<Long> viewIds = tool.getViewIds();
|
||||||
if (Agent.containsAllModel(new HashSet<>(modelIds))) {
|
if (Agent.containsAllModel(new HashSet<>(viewIds))) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
for (Long modelId : modelIdSet) {
|
return viewIds.contains(viewId);
|
||||||
if (modelIds.contains(modelId)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
})
|
})
|
||||||
.findFirst();
|
.findFirst();
|
||||||
return llmParserTool.orElse(null);
|
return llmParserTool.orElse(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
public LLMReq getLlmReq(QueryContext queryCtx, Long viewId,
|
||||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
|
||||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
Map<Long, String> viewIdToName = semanticSchema.getViewIdToName();
|
||||||
String queryText = queryCtx.getQueryText();
|
String queryText = queryCtx.getQueryText();
|
||||||
|
|
||||||
LLMReq llmReq = new LLMReq();
|
LLMReq llmReq = new LLMReq();
|
||||||
llmReq.setQueryText(queryText);
|
llmReq.setQueryText(queryText);
|
||||||
Long firstModelId = modelCluster.getFirstModel();
|
|
||||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||||
llmReq.setFilterCondition(filterCondition);
|
llmReq.setFilterCondition(filterCondition);
|
||||||
|
|
||||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||||
llmSchema.setModelName(modelIdToName.get(firstModelId));
|
llmSchema.setViewName(viewIdToName.get(viewId));
|
||||||
llmSchema.setDomainName(modelIdToName.get(firstModelId));
|
llmSchema.setDomainName(viewIdToName.get(viewId));
|
||||||
|
|
||||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelCluster, llmParserConfig);
|
List<String> fieldNameList = getFieldNameList(queryCtx, viewId, llmParserConfig);
|
||||||
|
|
||||||
String priorExts = getPriorExts(modelCluster.getModelIds(), fieldNameList);
|
String priorExts = getPriorExts(viewId, fieldNameList);
|
||||||
llmReq.setPriorExts(priorExts);
|
llmReq.setPriorExts(priorExts);
|
||||||
|
|
||||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||||
@@ -148,7 +121,7 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
llmReq.setLinking(linking);
|
llmReq.setLinking(linking);
|
||||||
|
|
||||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, firstModelId);
|
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, viewId);
|
||||||
if (StringUtils.isEmpty(currentDate)) {
|
if (StringUtils.isEmpty(currentDate)) {
|
||||||
currentDate = DateUtils.getBeforeDate(0);
|
currentDate = DateUtils.getBeforeDate(0);
|
||||||
}
|
}
|
||||||
@@ -157,29 +130,28 @@ public class LLMRequestService {
|
|||||||
return llmReq;
|
return llmReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp requestLLM(LLMReq llmReq, Long viewId) {
|
||||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, modelClusterKey);
|
return ComponentFactory.getLLMProxy().query2sql(llmReq, viewId);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long viewId,
|
||||||
LLMParserConfig llmParserConfig) {
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(queryCtx, modelCluster, llmParserConfig);
|
Set<String> results = getTopNFieldNames(queryCtx, viewId, llmParserConfig);
|
||||||
|
|
||||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
|
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, viewId);
|
||||||
|
|
||||||
results.addAll(fieldNameList);
|
results.addAll(fieldNameList);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getPriorExts(Set<Long> modelIds, List<String> fieldNameList) {
|
private String getPriorExts(Long viewId, List<String> fieldNameList) {
|
||||||
StringBuilder extraInfoSb = new StringBuilder();
|
StringBuilder extraInfoSb = new StringBuilder();
|
||||||
List<ModelSchemaResp> modelSchemaResps = semanticInterpreter.fetchModelSchema(
|
List<ViewSchemaResp> viewSchemaResps = semanticInterpreter.fetchViewSchema(
|
||||||
new ArrayList<>(modelIds), true);
|
Lists.newArrayList(viewId), true);
|
||||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
|
||||||
|
ViewSchemaResp viewSchemaResp = viewSchemaResps.get(0);
|
||||||
ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0);
|
Map<String, String> fieldNameToDataFormatType = viewSchemaResp.getMetrics()
|
||||||
Map<String, String> fieldNameToDataFormatType = modelSchemaResp.getMetrics()
|
|
||||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||||
.flatMap(metricSchemaResp -> {
|
.flatMap(metricSchemaResp -> {
|
||||||
Set<Pair<String, String>> result = new HashSet<>();
|
Set<Pair<String, String>> result = new HashSet<>();
|
||||||
@@ -207,11 +179,9 @@ public class LLMRequestService {
|
|||||||
return extraInfoSb.toString();
|
return extraInfoSb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
protected List<ElementValue> getValueList(QueryContext queryCtx, Long viewId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||||
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
|
||||||
.getMatchedElements(modelCluster.getKey());
|
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
@@ -231,22 +201,21 @@ public class LLMRequestService {
|
|||||||
return new ArrayList<>(valueMatches);
|
return new ArrayList<>(valueMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, ModelCluster modelCluster) {
|
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long viewId) {
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
return semanticSchema.getDimensions(viewId).stream()
|
||||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, ModelCluster modelCluster,
|
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long viewId, LLMParserConfig llmParserConfig) {
|
||||||
LLMParserConfig llmParserConfig) {
|
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
Set<String> results = semanticSchema.getDimensions(viewId).stream()
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
.limit(llmParserConfig.getDimensionTopN())
|
.limit(llmParserConfig.getDimensionTopN())
|
||||||
.map(entry -> entry.getName())
|
.map(entry -> entry.getName())
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(modelCluster.getModelIds()).stream()
|
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
.limit(llmParserConfig.getMetricTopN())
|
.limit(llmParserConfig.getMetricTopN())
|
||||||
.map(entry -> entry.getName())
|
.map(entry -> entry.getName())
|
||||||
@@ -256,10 +225,9 @@ public class LLMRequestService {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long viewId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||||
.getMatchedElements(modelCluster.getKey());
|
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return new HashSet<>();
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
|||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.MapUtils;
|
import org.apache.commons.collections4.MapUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
@@ -28,10 +28,9 @@ public class LLMResponseService {
|
|||||||
}
|
}
|
||||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
parseInfo.setModel(parseResult.getModelCluster());
|
parseInfo.setView(queryCtx.getSemanticSchema().getView(parseResult.getViewId()));
|
||||||
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
|
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||||
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
|
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getViewId()));
|
||||||
.getMatchedElements(parseInfo.getModelClusterKey()));
|
|
||||||
|
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
properties.put(Constants.CONTEXT, parseResult);
|
properties.put(Constants.CONTEXT, parseResult);
|
||||||
@@ -42,7 +41,6 @@ public class LLMResponseService {
|
|||||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||||
parseInfo.setModel(parseResult.getModelCluster());
|
|
||||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||||
return parseInfo;
|
return parseInfo;
|
||||||
}
|
}
|
||||||
@@ -54,7 +52,7 @@ public class LLMResponseService {
|
|||||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||||
String key = entry.getKey();
|
String key = entry.getKey();
|
||||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
result.put(key, entry.getValue());
|
result.put(key, entry.getValue());
|
||||||
|
|||||||
@@ -9,14 +9,13 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
|||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections4.MapUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LLMSqlParser implements SemanticParser {
|
public class LLMSqlParser implements SemanticParser {
|
||||||
@@ -30,31 +29,30 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
//2.get modelId from queryCtx and chatCtx.
|
//2.get modelId from queryCtx and chatCtx.
|
||||||
ModelCluster modelCluster = requestService.getModelCluster(queryCtx, chatCtx);
|
Long viewId = requestService.getViewId(queryCtx);
|
||||||
if (StringUtils.isBlank(modelCluster.getKey())) {
|
if (viewId == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//3.get agent tool and determine whether to skip this parser.
|
//3.get agent tool and determine whether to skip this parser.
|
||||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, modelCluster.getModelIds());
|
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, viewId);
|
||||||
if (Objects.isNull(commonAgentTool)) {
|
if (Objects.isNull(commonAgentTool)) {
|
||||||
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
|
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelCluster);
|
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, viewId);
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, semanticSchema, modelCluster, linkingValues);
|
LLMReq llmReq = requestService.getLlmReq(queryCtx, viewId, semanticSchema, linkingValues);
|
||||||
LLMResp llmResp = requestService.requestLLM(llmReq, modelCluster.getKey());
|
LLMResp llmResp = requestService.requestLLM(llmReq, viewId);
|
||||||
|
|
||||||
if (Objects.isNull(llmResp)) {
|
if (Objects.isNull(llmResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//5. deduplicate the SQL result list and build parserInfo
|
//5. deduplicate the SQL result list and build parserInfo
|
||||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
|
||||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||||
ParseResult parseResult = ParseResult.builder()
|
ParseResult parseResult = ParseResult.builder()
|
||||||
.modelCluster(modelCluster)
|
.viewId(viewId)
|
||||||
.commonAgentTool(commonAgentTool)
|
.commonAgentTool(commonAgentTool)
|
||||||
.llmReq(llmReq)
|
.llmReq(llmReq)
|
||||||
.llmResp(llmResp)
|
.llmResp(llmResp)
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
public interface ModelResolver {
|
|
||||||
|
|
||||||
String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -11,11 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
@@ -24,6 +19,12 @@ import org.springframework.beans.factory.InitializingBean;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||||
@@ -33,7 +34,7 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private ChatLanguageModel chatLanguageModel;
|
private ChatLanguageModel chatLanguageModel;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private SqlExampleLoader sqlExampleLoader;
|
private SqlExamplarLoader sqlExamplarLoader;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
@@ -42,12 +43,12 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||||
//1.retriever sqlExamples and generate exampleListPool
|
//1.retriever sqlExamples and generate exampleListPool
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
|
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||||
|
|||||||
@@ -2,19 +2,16 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
@@ -22,6 +19,10 @@ import org.springframework.beans.factory.InitializingBean;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||||
@@ -31,7 +32,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private ChatLanguageModel chatLanguageModel;
|
private ChatLanguageModel chatLanguageModel;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private SqlExampleLoader sqlExampleLoader;
|
private SqlExamplarLoader sqlExampleLoader;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
@@ -40,11 +41,11 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||||
//1.retriever sqlExamples
|
//1.retriever sqlExamples
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||||
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
@@ -19,7 +18,7 @@ import java.util.List;
|
|||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class ParseResult {
|
public class ParseResult {
|
||||||
|
|
||||||
private ModelCluster modelCluster;
|
private Long viewId;
|
||||||
|
|
||||||
private LLMReq llmReq;
|
private LLMReq llmReq;
|
||||||
|
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.common.util.DatePeriodEnum;
|
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
public class S2SqlDateHelper {
|
|
||||||
|
|
||||||
public static String getReferenceDate(QueryContext queryContext, Long modelId) {
|
|
||||||
String defaultDate = DateUtils.getBeforeDate(0);
|
|
||||||
if (Objects.isNull(modelId)) {
|
|
||||||
return defaultDate;
|
|
||||||
}
|
|
||||||
ChatConfigFilter filter = new ChatConfigFilter();
|
|
||||||
filter.setModelId(modelId);
|
|
||||||
ChatConfigRichResp chatConfigRichResp = queryContext.getModelIdToChatRichConfig().get(modelId);
|
|
||||||
|
|
||||||
if (Objects.isNull(chatConfigRichResp)) {
|
|
||||||
return defaultDate;
|
|
||||||
}
|
|
||||||
if (Objects.isNull(chatConfigRichResp.getChatDetailRichConfig()) || Objects.isNull(
|
|
||||||
chatConfigRichResp.getChatDetailRichConfig().getChatDefaultConfig())) {
|
|
||||||
return defaultDate;
|
|
||||||
}
|
|
||||||
|
|
||||||
ChatDefaultRichConfigResp chatDefaultConfig = chatConfigRichResp.getChatDetailRichConfig()
|
|
||||||
.getChatDefaultConfig();
|
|
||||||
Integer unit = chatDefaultConfig.getUnit();
|
|
||||||
String period = chatDefaultConfig.getPeriod();
|
|
||||||
if (Objects.nonNull(unit)) {
|
|
||||||
// If the unit is set to less than 0, then do not add relative date.
|
|
||||||
if (unit < 0) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
|
|
||||||
if (Objects.isNull(datePeriodEnum)) {
|
|
||||||
return DateUtils.getBeforeDate(unit);
|
|
||||||
} else {
|
|
||||||
return DateUtils.getBeforeDate(unit, datePeriodEnum);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return defaultDate;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|||||||
|
|
||||||
|
|
||||||
import com.fasterxml.jackson.core.type.TypeReference;
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
@@ -19,12 +20,13 @@ import java.util.Objects;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.core.io.ClassPathResource;
|
import org.springframework.core.io.ClassPathResource;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
public class SqlExampleLoader {
|
public class SqlExamplarLoader {
|
||||||
|
|
||||||
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
||||||
|
|
||||||
@@ -32,6 +34,9 @@ public class SqlExampleLoader {
|
|||||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
public List<SqlExample> getSqlExamples() throws IOException {
|
public List<SqlExample> getSqlExamples() throws IOException {
|
||||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||||
InputStream inputStream = resource.getInputStream();
|
InputStream inputStream = resource.getInputStream();
|
||||||
@@ -53,8 +58,8 @@ public class SqlExampleLoader {
|
|||||||
s2EmbeddingStore.addQuery(collectionName, queries);
|
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
|
public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
|
||||||
|
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||||
.queryEmbeddings(null).build();
|
.queryEmbeddings(null).build();
|
||||||
|
|
||||||
@@ -12,9 +12,9 @@ public interface SqlGeneration {
|
|||||||
/***
|
/***
|
||||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
||||||
* @param llmReq
|
* @param llmReq
|
||||||
* @param modelClusterKey
|
* @param viewId
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
LLMResp generation(LLMReq llmReq, String modelClusterKey);
|
LLMResp generation(LLMReq llmReq, Long viewId);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,15 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -95,7 +96,7 @@ public class SqlPromptGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||||
String modelName = llmReq.getSchema().getModelName();
|
String modelName = llmReq.getSchema().getViewName();
|
||||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||||
List<ElementValue> linking = llmReq.getLinking();
|
List<ElementValue> linking = llmReq.getLinking();
|
||||||
String currentDate = llmReq.getCurrentDate();
|
String currentDate = llmReq.getCurrentDate();
|
||||||
|
|||||||
@@ -11,10 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
@@ -22,6 +18,11 @@ import org.springframework.beans.factory.InitializingBean;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private ChatLanguageModel chatLanguageModel;
|
private ChatLanguageModel chatLanguageModel;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private SqlExampleLoader sqlExampleLoader;
|
private SqlExamplarLoader sqlExamplarLoader;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
@@ -39,11 +40,11 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||||
//1.retriever sqlExamples and generate exampleListPool
|
//1.retriever sqlExamples and generate exampleListPool
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||||
|
|||||||
@@ -11,9 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
|||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
@@ -21,6 +18,10 @@ import org.springframework.beans.factory.InitializingBean;
|
|||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||||
@@ -30,7 +31,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private ChatLanguageModel chatLanguageModel;
|
private ChatLanguageModel chatLanguageModel;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private SqlExampleLoader sqlExampleLoader;
|
private SqlExamplarLoader sqlExamplarLoader;
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
@@ -39,10 +40,10 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
|||||||
private SqlPromptGenerator sqlPromptGenerator;
|
private SqlPromptGenerator sqlPromptGenerator;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
optimizationConfig.getText2sqlExampleNum());
|
||||||
|
|
||||||
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelMatchResult {
|
public class ViewMatchResult {
|
||||||
private Integer count = 0;
|
private Integer count = 0;
|
||||||
private double maxSimilarity;
|
private double maxSimilarity;
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||||
|
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
|
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
public interface ViewResolver {
|
||||||
|
|
||||||
|
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -11,32 +11,34 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
|||||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class AgentCheckParser implements SemanticParser {
|
public class AgentCheckParser implements SemanticParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
List<SemanticQuery> queries = queryContext.getCandidateQueries();
|
List<SemanticQuery> queries = queryContext.getCandidateQueries();
|
||||||
agentCanSupport(queryContext, queries);
|
log.info("query size before agent filter:{}", queryContext.getCandidateQueries().size());
|
||||||
|
filterQueries(queryContext, queries);
|
||||||
|
log.info("query size after agent filter: {}", queryContext.getCandidateQueries().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void agentCanSupport(QueryContext queryContext, List<SemanticQuery> queries) {
|
private void filterQueries(QueryContext queryContext, List<SemanticQuery> queries) {
|
||||||
Agent agent = queryContext.getAgent();
|
Agent agent = queryContext.getAgent();
|
||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<RuleParserTool> queryTools = getRuleTools(agent);
|
List<RuleParserTool> queryTools = getRuleTools(agent);
|
||||||
if (CollectionUtils.isEmpty(queryTools)) {
|
if (CollectionUtils.isEmpty(queryTools)) {
|
||||||
queries.clear();
|
queryContext.setCandidateQueries(Lists.newArrayList());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
log.info("queries resolved:{} {}", agent.getName(),
|
log.info("agent name :{}, queries resolved: {}", agent.getName(),
|
||||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||||
queries.removeIf(query -> {
|
queries.removeIf(query -> {
|
||||||
for (RuleParserTool tool : queryTools) {
|
for (RuleParserTool tool : queryTools) {
|
||||||
@@ -46,26 +48,28 @@ public class AgentCheckParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
|
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
|
||||||
if (QueryManager.isTagQuery(query.getQueryMode())) {
|
if (QueryManager.isTagQuery(query.getQueryMode())) {
|
||||||
return !tool.getQueryTypes().contains(QueryType.TAG.name());
|
if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (QueryManager.isMetricQuery(query.getQueryMode())) {
|
if (QueryManager.isMetricQuery(query.getQueryMode())) {
|
||||||
return !tool.getQueryTypes().contains(QueryType.METRIC.name());
|
if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (CollectionUtils.isEmpty(tool.getModelIds())) {
|
}
|
||||||
|
if (CollectionUtils.isEmpty(tool.getViewIds())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (tool.isContainsAllModel()) {
|
if (tool.isContainsAllModel()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (new HashSet<>(tool.getModelIds())
|
return !tool.getViewIds().contains(query.getParseInfo().getViewId());
|
||||||
.containsAll(query.getParseInfo().getModel().getModelIds())) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
});
|
});
|
||||||
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
|
queryContext.setCandidateQueries(queries);
|
||||||
|
log.info("agent name :{}, rule queries witch can be supported by agent :{}", agent.getName(),
|
||||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
@@ -12,8 +11,8 @@ import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
|||||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
||||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricSemanticQuery;
|
import com.tencent.supersonic.chat.core.query.rule.metric.MetricSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
|
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
|
||||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
|
||||||
import java.util.AbstractMap;
|
import java.util.AbstractMap;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
@@ -23,8 +22,6 @@ import java.util.Objects;
|
|||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ContextInheritParser tries to inherit certain schema elements from context
|
* ContextInheritParser tries to inherit certain schema elements from context
|
||||||
@@ -42,7 +39,7 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||||
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
|
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
|
||||||
new AbstractMap.SimpleEntry<>(SchemaElementType.MODEL, Arrays.asList(SchemaElementType.MODEL)),
|
new AbstractMap.SimpleEntry<>(SchemaElementType.VIEW, Arrays.asList(SchemaElementType.VIEW)),
|
||||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||||
|
|
||||||
@@ -51,12 +48,13 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
if (!shouldInherit(queryContext)) {
|
if (!shouldInherit(queryContext)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ModelCluster modelCluster = getMatchedModelCluster(queryContext, chatContext);
|
Long viewId = getMatchedView(queryContext, chatContext);
|
||||||
if (modelCluster == null) {
|
if (viewId == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<SchemaElementMatch> elementMatches = queryContext.getModelClusterMapInfo()
|
|
||||||
.getMatchedElements(modelCluster.getKey());
|
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||||
|
|
||||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||||
SchemaElementType matchType = match.getElement().getType();
|
SchemaElementType matchType = match.getElement().getType();
|
||||||
@@ -72,17 +70,17 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||||
for (RuleSemanticQuery query : queries) {
|
for (RuleSemanticQuery query : queries) {
|
||||||
query.fillParseInfo(queryContext, chatContext);
|
query.fillParseInfo(queryContext, chatContext);
|
||||||
if (existSameQuery(query.getParseInfo().getModelClusterKey(), query.getQueryMode(), queryContext)) {
|
if (existSameQuery(query.getParseInfo().getViewId(), query.getQueryMode(), queryContext)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
queryContext.getCandidateQueries().add(query);
|
queryContext.getCandidateQueries().add(query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean existSameQuery(String modelClusterKey, String queryMode, QueryContext queryContext) {
|
private boolean existSameQuery(Long viewId, String queryMode, QueryContext queryContext) {
|
||||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||||
&& semanticQuery.getParseInfo().getModelClusterKey().equals(modelClusterKey)) {
|
&& semanticQuery.getParseInfo().getViewId().equals(viewId)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -111,25 +109,16 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected ModelCluster getMatchedModelCluster(QueryContext queryContext, ChatContext chatContext) {
|
protected Long getMatchedView(QueryContext queryContext, ChatContext chatContext) {
|
||||||
String contextModelClusterKey = chatContext.getParseInfo().getModelClusterKey();
|
Long viewId = chatContext.getParseInfo().getViewId();
|
||||||
if (StringUtils.isBlank(contextModelClusterKey)) {
|
if (viewId == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
Set<Long> queryViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||||
List<ModelCluster> allModelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
if (queryViews.contains(viewId)) {
|
||||||
Set<String> queryModelClusters = queryContext.getModelClusterMapInfo().getMatchedModelClusters();
|
return viewId;
|
||||||
ModelCluster contextModelCluster = ModelCluster.build(contextModelClusterKey);
|
|
||||||
for (String cluster : queryModelClusters) {
|
|
||||||
ModelCluster queryModelCluster = ModelCluster.build(cluster);
|
|
||||||
for (ModelCluster modelCluster : allModelClusters) {
|
|
||||||
if (modelCluster.getModelIds().containsAll(contextModelCluster.getModelIds())
|
|
||||||
&& modelCluster.getModelIds().containsAll(queryModelCluster.getModelIds())) {
|
|
||||||
return queryModelCluster;
|
|
||||||
}
|
}
|
||||||
}
|
return viewId;
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -27,10 +27,10 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
SchemaModelClusterMapInfo modelClusterMapInfo = queryContext.getModelClusterMapInfo();
|
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||||
// iterate all schemaElementMatches to resolve query mode
|
// iterate all schemaElementMatches to resolve query mode
|
||||||
for (String modelClusterKey : modelClusterMapInfo.getMatchedModelClusters()) {
|
for (Long viewId : mapInfo.getMatchedViewInfos()) {
|
||||||
List<SchemaElementMatch> elementMatches = modelClusterMapInfo.getMatchedElements(modelClusterKey);
|
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(viewId);
|
||||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||||
for (RuleSemanticQuery query : queries) {
|
for (RuleSemanticQuery query : queries) {
|
||||||
query.fillParseInfo(queryContext, chatContext);
|
query.fillParseInfo(queryContext, chatContext);
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
|
|||||||
*/
|
*/
|
||||||
private String type;
|
private String type;
|
||||||
|
|
||||||
private List<Long> modelList = Lists.newArrayList();
|
private List<Long> viewList = Lists.newArrayList();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* description, for parsing
|
* description, for parsing
|
||||||
@@ -52,7 +52,7 @@ public class Plugin extends RecordInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public boolean isContainsAllModel() {
|
public boolean isContainsAllModel() {
|
||||||
return CollectionUtils.isNotEmpty(modelList) && modelList.contains(-1L);
|
return CollectionUtils.isNotEmpty(viewList) && viewList.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getDefaultMode() {
|
public Long getDefaultMode() {
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.core.plugin;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||||
@@ -23,6 +23,12 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
|
|||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.context.event.EventListener;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -32,11 +38,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.context.event.EventListener;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
@@ -265,14 +266,14 @@ public class PluginManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
|
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
|
||||||
Set<Long> matchedModel = queryContext.getMapInfo().getMatchedModels();
|
Set<Long> matchedViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||||
if (plugin.isContainsAllModel()) {
|
if (plugin.isContainsAllModel()) {
|
||||||
return Sets.newHashSet(plugin.getDefaultMode());
|
return Sets.newHashSet(plugin.getDefaultMode());
|
||||||
}
|
}
|
||||||
List<Long> modelIds = plugin.getModelList();
|
List<Long> modelIds = plugin.getViewList();
|
||||||
Set<Long> pluginMatchedModel = Sets.newHashSet();
|
Set<Long> pluginMatchedModel = Sets.newHashSet();
|
||||||
for (Long modelId : modelIds) {
|
for (Long modelId : modelIds) {
|
||||||
if (matchedModel.contains(modelId)) {
|
if (matchedViews.contains(modelId)) {
|
||||||
pluginMatchedModel.add(modelId);
|
pluginMatchedModel.add(modelId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import lombok.AllArgsConstructor;
|
|||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -14,7 +15,7 @@ public class PluginRecallResult {
|
|||||||
|
|
||||||
private Plugin plugin;
|
private Plugin plugin;
|
||||||
|
|
||||||
private Set<Long> modelIds;
|
private Set<Long> viewIds;
|
||||||
|
|
||||||
private double score;
|
private double score;
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.core.pojo;
|
package com.tencent.supersonic.chat.core.pojo;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||||
@@ -11,15 +11,16 @@ import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
|||||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@@ -29,18 +30,22 @@ public class QueryContext {
|
|||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private Long modelId;
|
private Long viewId;
|
||||||
private User user;
|
private User user;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
@JsonIgnore
|
||||||
private SemanticSchema semanticSchema;
|
private SemanticSchema semanticSchema;
|
||||||
|
@JsonIgnore
|
||||||
private Agent agent;
|
private Agent agent;
|
||||||
|
@JsonIgnore
|
||||||
private Map<Long, ChatConfigRichResp> modelIdToChatRichConfig;
|
private Map<Long, ChatConfigRichResp> modelIdToChatRichConfig;
|
||||||
|
@JsonIgnore
|
||||||
private Map<String, Plugin> nameToPlugin;
|
private Map<String, Plugin> nameToPlugin;
|
||||||
|
@JsonIgnore
|
||||||
private List<Plugin> pluginList;
|
private List<Plugin> pluginList;
|
||||||
|
|
||||||
public List<SemanticQuery> getCandidateQueries() {
|
public List<SemanticQuery> getCandidateQueries() {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
||||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||||
@@ -14,20 +14,21 @@ import com.tencent.supersonic.common.pojo.Filter;
|
|||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.enums.QueryType;
|
import com.tencent.supersonic.headless.api.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||||
import com.tencent.supersonic.headless.api.request.QuerySqlReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||||
|
import lombok.ToString;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.ToString;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@ToString
|
@ToString
|
||||||
@@ -48,7 +49,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
|||||||
explainSqlReq = ExplainSqlReq.builder()
|
explainSqlReq = ExplainSqlReq.builder()
|
||||||
.queryTypeEnum(QueryType.SQL)
|
.queryTypeEnum(QueryType.SQL)
|
||||||
.queryReq(QueryReqBuilder.buildS2SQLReq(
|
.queryReq(QueryReqBuilder.buildS2SQLReq(
|
||||||
sqlInfo.getCorrectS2SQL(), parseInfo.getModel().getModelIds()
|
sqlInfo.getCorrectS2SQL(), parseInfo.getViewId()
|
||||||
))
|
))
|
||||||
.build();
|
.build();
|
||||||
} else {
|
} else {
|
||||||
@@ -83,7 +84,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
||||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getModelIdSet());
|
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getViewId());
|
||||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||||
|
|
||||||
List<Order> orders = queryStructReq.getOrders();
|
List<Order> orders = queryStructReq.getOrders();
|
||||||
@@ -100,18 +101,17 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
|||||||
}
|
}
|
||||||
List<String> groups = queryStructReq.getGroups();
|
List<String> groups = queryStructReq.getGroups();
|
||||||
if (CollectionUtils.isNotEmpty(groups)) {
|
if (CollectionUtils.isNotEmpty(groups)) {
|
||||||
groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList());
|
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||||
queryStructReq.setGroups(groups);
|
queryStructReq.setGroups(groups);
|
||||||
}
|
}
|
||||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||||
dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||||
}
|
}
|
||||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||||
metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||||
}
|
}
|
||||||
queryStructReq.setModelName(parseInfo.getModelName());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
|
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
|
||||||
@@ -121,7 +121,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
|||||||
}
|
}
|
||||||
QueryStructReq queryStructReq = convertQueryStruct();
|
QueryStructReq queryStructReq = convertQueryStruct();
|
||||||
convertBizNameToName(semanticSchema, queryStructReq);
|
convertBizNameToName(semanticSchema, queryStructReq);
|
||||||
QuerySqlReq querySQLReq = queryStructReq.convert(queryStructReq);
|
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||||
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
|
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
|
||||||
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
|
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,14 @@ package com.tencent.supersonic.chat.core.query.llm.analytics;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||||
@@ -19,8 +19,8 @@ import com.tencent.supersonic.common.pojo.QueryColumn;
|
|||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||||
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user