mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Compare commits
78 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
532a00518c | ||
|
|
895f38b6f7 | ||
|
|
eba3a8ad34 | ||
|
|
6813582ea0 | ||
|
|
6b6e54e95f | ||
|
|
26260c79f1 | ||
|
|
0a42932db2 | ||
|
|
3b4678d682 | ||
|
|
c0458ccf0e | ||
|
|
c7c70208ff | ||
|
|
3a38200448 | ||
|
|
b72e280990 | ||
|
|
0541614dad | ||
|
|
0beb3cefd3 | ||
|
|
afdf18398c | ||
|
|
bdf7df933b | ||
|
|
a909493414 | ||
|
|
aa86fc9275 | ||
|
|
e36060eae4 | ||
|
|
3893e897cb | ||
|
|
617cd87a48 | ||
|
|
042a610231 | ||
|
|
b555beae21 | ||
|
|
ba01cdb9bc | ||
|
|
e610dd8246 | ||
|
|
01bc4dcacf | ||
|
|
c224b81160 | ||
|
|
bfd0e040da | ||
|
|
f50a3157d5 | ||
|
|
61316e939c | ||
|
|
fab1bac50c | ||
|
|
d8043c356f | ||
|
|
e95a528219 | ||
|
|
16643e8d75 | ||
|
|
417a43dee8 | ||
|
|
4a22fdf452 | ||
|
|
16afbc6945 | ||
|
|
fc5ff01eca | ||
|
|
d10801ef38 | ||
|
|
33240cc382 | ||
|
|
3317f1b7ec | ||
|
|
b85778babd | ||
|
|
699a33b1c1 | ||
|
|
fdb69547e6 | ||
|
|
39158d6877 | ||
|
|
329ad327b0 | ||
|
|
9600456bae | ||
|
|
74d0ec2b23 | ||
|
|
8a342eb32a | ||
|
|
e801c448be | ||
|
|
da5e7b9b75 | ||
|
|
75853a8e9e | ||
|
|
2546d1c0e1 | ||
|
|
0c4c6d83ef | ||
|
|
4d4922d269 | ||
|
|
1004f71ba4 | ||
|
|
c13a0e672c | ||
|
|
491c76368c | ||
|
|
2c1c443b3e | ||
|
|
f29b1854ba | ||
|
|
7f15bacca4 | ||
|
|
df975b231d | ||
|
|
24b442baef | ||
|
|
31f8c1df35 | ||
|
|
26aefceb04 | ||
|
|
954c67c947 | ||
|
|
fdfad515dd | ||
|
|
c398ac1a84 | ||
|
|
aae3d6b297 | ||
|
|
923c65b2f9 | ||
|
|
22775343f4 | ||
|
|
d9533c53ea | ||
|
|
841db25198 | ||
|
|
922201c181 | ||
|
|
48fb01f6bc | ||
|
|
9d6f96e6d4 | ||
|
|
42a6f61456 | ||
|
|
163e782f51 |
35
.github/workflows/mac-ci.yml
vendored
Normal file
35
.github/workflows/mac-ci.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: supersonic mac CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: macos-latest # Specify a macOS runner
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up JDK 8
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: '8'
|
||||
distribution: 'adopt'
|
||||
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
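path: ~/Library/Caches/Maven # NOTE(review): Maven's default local repository is ~/.m2/repository on macOS too; confirm this directory is actually populated by the build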
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2
|
||||
|
||||
- name: Build with Maven
|
||||
run: mvn -B package --file pom.xml
|
||||
|
||||
- name: Test with Maven
|
||||
run: mvn test
|
||||
@@ -1,4 +1,4 @@
|
||||
name: supersonic CI
|
||||
name: supersonic ubuntu CI
|
||||
|
||||
on:
|
||||
push:
|
||||
35
.github/workflows/windows-ci.yml
vendored
Normal file
35
.github/workflows/windows-ci.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: supersonic windows CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
branches:
|
||||
- master
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: windows-latest # Specify a Windows runner
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up JDK 8
|
||||
uses: actions/setup-java@v2
|
||||
with:
|
||||
java-version: '8'
|
||||
distribution: 'adopt'
|
||||
|
||||
- name: Cache Maven packages
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~\.m2 # Windows uses a backslash for paths
|
||||
key: ${{ runner.os }}-m2-${{ hashFiles('**/pom.xml') }}
|
||||
restore-keys: ${{ runner.os }}-m2
|
||||
|
||||
- name: Build with Maven
|
||||
run: mvn -B package --file pom.xml
|
||||
|
||||
- name: Test with Maven
|
||||
run: mvn test
|
||||
@@ -4,6 +4,14 @@
|
||||
- "Breaking Changes" describes any changes that may break existing functionality or cause
|
||||
compatibility issues with previous versions.
|
||||
|
||||
## SuperSonic [0.8.6] - 2024-02-23
|
||||
|
||||
### Added
|
||||
- support view abstraction in Headless.
|
||||
- add the Metric API to Headless and optimize the Headless API.
|
||||
- add integration tests to Headless.
|
||||
- add TimeCorrector to Chat.
|
||||
|
||||
## SuperSonic [0.8.4] - 2024-01-19
|
||||
|
||||
### Added
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
# SuperSonic (超音数)
|
||||
|
||||
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) on top of physical data models, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||
**SuperSonic is the next-generation LLM-powered data analytics platform that integrates ChatBI and HeadlessBI**. SuperSonic provides a chat interface that empowers users to query data using natural language and visualize the results with suitable charts. To enable such an experience, the only thing necessary is to build logical semantic models (definition of entities/metrics/dimensions/tags, along with their meaning, context and relationships) with a semantic layer, and **no data modification or copying** is required. Meanwhile, SuperSonic is designed to be **highly extensible**, allowing custom functionalities to be added and configured with Java SPI.
|
||||
|
||||
<img src="./docs/images/supersonic_demo.gif" height="100%" width="100%" align="center"/>
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
The emergence of Large Language Models (LLMs) like ChatGPT is reshaping the way information is retrieved. In the field of data analytics, both academia and industry are primarily focused on leveraging LLMs to convert natural language into SQL (so-called Text2SQL or NL2SQL). While some approaches exhibit promising results, their **reliability** and **efficiency** are insufficient for real-world applications.
|
||||
|
||||
From our perspective, the key to filling the real-world gap lies in three aspects:
|
||||
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
|
||||
1. Integrate ChatBI with HeadlessBI encapsulating underlying data context (joins, keys, formulas, etc) to **reduce complexity**.
|
||||
<img src="./docs/images/supersonic_ideas.png" height="65%" width="65%" align="center"/>
|
||||
2. Augment the LLM with schema mappers (as a kind of preprocessor) and semantic correctors (as a kind of postprocessor) to **mitigate hallucination**.
|
||||
3. Utilize rule-based schema parsers when necessary to **improve efficiency** (in terms of latency and cost).
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
|
||||
在我们看来,为了在实际场景发挥价值,有三个关键点:
|
||||
1. 融合HeadlessBI,通过统一语义层封装底层数据细节(关联、键值、公式等),降低SQL生成的**复杂度**。
|
||||
|
||||
<img src="./docs/images/supersonic_ideas.png" height="65%" width="65%" align="center"/>
|
||||
2. 通过一前一后的模式映射器和语义修正器,来缓解LLM常见的**幻觉**现象。
|
||||
3. 设计启发式的规则,在一些特定场景提升语义解析的**效率**。
|
||||
|
||||
|
||||
@@ -24,6 +24,10 @@ public class User {
|
||||
return new User(id, name, displayName, email, isAdmin);
|
||||
}
|
||||
|
||||
public static User get(Long id, String name) {
|
||||
return new User(id, name, name, name, 0);
|
||||
}
|
||||
|
||||
public static User getFakeUser() {
|
||||
return new User(1L, "admin", "admin", "admin@email", 1);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -7,25 +9,25 @@ import java.util.Set;
|
||||
|
||||
public class SchemaMapInfo {
|
||||
|
||||
private Map<Long, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
|
||||
|
||||
public Set<Long> getMatchedModels() {
|
||||
return modelElementMatches.keySet();
|
||||
public Set<Long> getMatchedViewInfos() {
|
||||
return viewElementMatches.keySet();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(Long model) {
|
||||
return modelElementMatches.get(model);
|
||||
public List<SchemaElementMatch> getMatchedElements(Long view) {
|
||||
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
|
||||
}
|
||||
|
||||
public Map<Long, List<SchemaElementMatch>> getModelElementMatches() {
|
||||
return modelElementMatches;
|
||||
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
|
||||
return viewElementMatches;
|
||||
}
|
||||
|
||||
public void setModelElementMatches(Map<Long, List<SchemaElementMatch>> modelElementMatches) {
|
||||
this.modelElementMatches = modelElementMatches;
|
||||
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
|
||||
this.viewElementMatches = viewElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(Long model, List<SchemaElementMatch> elementMatches) {
|
||||
modelElementMatches.put(model, elementMatches);
|
||||
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
|
||||
viewElementMatches.put(view, elementMatches);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class SchemaModelClusterMapInfo {
|
||||
|
||||
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||
|
||||
public Set<String> getMatchedModelClusters() {
|
||||
return modelElementMatches.keySet();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
return modelElementMatches.get(key);
|
||||
}
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
||||
return modelElementMatches.get(modelCluster);
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
||||
return modelElementMatches;
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
||||
if (CollectionUtils.isEmpty(modelIds)) {
|
||||
return modelElementMatches;
|
||||
}
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
for (Long modelId : modelIds) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
||||
}
|
||||
}
|
||||
}
|
||||
return modelElementMatchesFiltered;
|
||||
}
|
||||
|
||||
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
||||
this.modelElementMatches = modelElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
||||
modelElementMatches.put(modelCluster, elementMatches);
|
||||
}
|
||||
}
|
||||
@@ -5,11 +5,11 @@ import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -26,7 +26,7 @@ public class SemanticParseInfo {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode;
|
||||
private ModelCluster model = new ModelCluster();
|
||||
private SchemaElement view;
|
||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||
private SchemaElement entity;
|
||||
@@ -44,20 +44,6 @@ public class SemanticParseInfo {
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private QueryType queryType = QueryType.ID;
|
||||
|
||||
public String getModelClusterKey() {
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getKey();
|
||||
}
|
||||
|
||||
public String getModelName() {
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getName();
|
||||
}
|
||||
|
||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||
|
||||
@Override
|
||||
@@ -86,27 +72,15 @@ public class SemanticParseInfo {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private Map<Long, Integer> getModelElementCountMap() {
|
||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||
elementMatches.stream().filter(element -> element.getElement().getModel() != null)
|
||||
.forEach(element -> {
|
||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||
});
|
||||
return elementCountMap;
|
||||
public Long getViewId() {
|
||||
if (view == null) {
|
||||
return null;
|
||||
}
|
||||
return view.getView();
|
||||
}
|
||||
|
||||
public Long getModelId() {
|
||||
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
||||
Long modelId = -1L;
|
||||
int maxCnt = 0;
|
||||
for (Long model : elementCountMap.keySet()) {
|
||||
if (elementCountMap.get(model) > maxCnt) {
|
||||
maxCnt = elementCountMap.get(model);
|
||||
modelId = model;
|
||||
}
|
||||
}
|
||||
return modelId;
|
||||
public SchemaElement getModel() {
|
||||
return view;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
|
||||
private List<ModelSchema> modelSchemaList;
|
||||
private List<ViewSchema> viewSchemaList;
|
||||
|
||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
||||
this.modelSchemaList = modelSchemaList;
|
||||
public SemanticSchema(List<ViewSchema> viewSchemaList) {
|
||||
this.viewSchemaList = viewSchemaList;
|
||||
}
|
||||
|
||||
public void add(ModelSchema schema) {
|
||||
modelSchemaList.add(schema);
|
||||
public void add(ViewSchema schema) {
|
||||
viewSchemaList.add(schema);
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
@@ -30,8 +30,8 @@ public class SemanticSchema implements Serializable {
|
||||
case ENTITY:
|
||||
element = getElementsById(elementID, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsById(elementID, getModels());
|
||||
case VIEW:
|
||||
element = getElementsById(elementID, getViews());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsById(elementID, getMetrics());
|
||||
@@ -52,58 +52,29 @@ public class SemanticSchema implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsByNameOrAlias(name, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsByNameOrAlias(name, getModels());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsByNameOrAlias(name, getMetrics());
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = getElementsByNameOrAlias(name, getDimensions());
|
||||
break;
|
||||
case VALUE:
|
||||
element = getElementsByNameOrAlias(name, getDimensionValues());
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public Map<Long, String> getModelIdToName() {
|
||||
return modelSchemaList.stream()
|
||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
||||
public Map<Long, String> getViewIdToName() {
|
||||
return viewSchemaList.stream()
|
||||
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensionValues() {
|
||||
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
return dimensionValues;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions() {
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getDimensions(Long viewId) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
return getElementsByModelId(modelIds, dimensions);
|
||||
return getElementsByViewId(viewId, dimensions);
|
||||
}
|
||||
|
||||
public SchemaElement getDimensions(Long id) {
|
||||
public SchemaElement getDimension(Long id) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||
return dimension.orElse(null);
|
||||
@@ -111,43 +82,43 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getTags() {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getTags(Long viewId) {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().filter(schemaElement ->
|
||||
modelIds.contains(schemaElement.getModel().getModel()))
|
||||
viewSchemaList.stream().filter(schemaElement ->
|
||||
viewId.equals(schemaElement.getView().getView()))
|
||||
.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics() {
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getMetrics(Long viewId) {
|
||||
List<SchemaElement> metrics = getMetrics();
|
||||
return getElementsByModelId(modelIds, metrics);
|
||||
return getElementsByViewId(viewId, metrics);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities() {
|
||||
List<SchemaElement> entities = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
return entities;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getEntities(Long viewId) {
|
||||
List<SchemaElement> entities = getEntities();
|
||||
return getElementsByModelId(modelIds, entities);
|
||||
return getElementsByViewId(viewId, entities);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -157,33 +128,30 @@ public class SemanticSchema implements Serializable {
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
private Optional<SchemaElement> getElementsByNameOrAlias(String name, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement ->
|
||||
name.equals(schemaElement.getName()) || (Objects.nonNull(schemaElement.getAlias())
|
||||
&& schemaElement.getAlias().contains(name))
|
||||
).findFirst();
|
||||
public SchemaElement getView(Long viewId) {
|
||||
List<SchemaElement> views = getViews();
|
||||
return getElementsById(viewId, views).orElse(null);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getModels() {
|
||||
List<SchemaElement> models = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||
return models;
|
||||
public List<SchemaElement> getViews() {
|
||||
List<SchemaElement> views = new ArrayList<>();
|
||||
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
|
||||
return views;
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
||||
public Map<String, String> getBizNameToName(Long viewId) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(getDimensions(modelIds));
|
||||
allElements.addAll(getMetrics(modelIds));
|
||||
allElements.addAll(getDimensions(viewId));
|
||||
allElements.addAll(getMetrics(viewId));
|
||||
return allElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public Map<Long, ModelSchema> getModelSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
||||
public Map<Long, ViewSchema> getViewSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(viewSchemaList)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
|
||||
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,26 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class ModelSchema {
|
||||
public class ViewSchema {
|
||||
|
||||
private SchemaElement model;
|
||||
private SchemaElement view;
|
||||
private Set<SchemaElement> metrics = new HashSet<>();
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
private Set<SchemaElement> tags = new HashSet<>();
|
||||
private SchemaElement entity = new SchemaElement();
|
||||
private List<ModelRela> modelRelas = new ArrayList<>();
|
||||
private QueryConfig queryConfig;
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
@@ -29,8 +29,8 @@ public class ModelSchema {
|
||||
case ENTITY:
|
||||
element = Optional.ofNullable(entity);
|
||||
break;
|
||||
case MODEL:
|
||||
element = Optional.of(model);
|
||||
case VIEW:
|
||||
element = Optional.of(view);
|
||||
break;
|
||||
case METRIC:
|
||||
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||
@@ -61,8 +61,8 @@ public class ModelSchema {
|
||||
case ENTITY:
|
||||
element = Optional.ofNullable(entity);
|
||||
break;
|
||||
case MODEL:
|
||||
element = Optional.of(model);
|
||||
case VIEW:
|
||||
element = Optional.of(view);
|
||||
break;
|
||||
case METRIC:
|
||||
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||
@@ -83,16 +83,31 @@ public class ModelSchema {
|
||||
}
|
||||
}
|
||||
|
||||
public Set<Long> getModelClusterSet() {
|
||||
if (CollectionUtils.isEmpty(modelRelas)) {
|
||||
return Sets.newHashSet();
|
||||
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
Set<Long> modelClusterSet = new HashSet<>();
|
||||
modelRelas.forEach(modelRela -> {
|
||||
modelClusterSet.add(modelRela.getToModelId());
|
||||
modelClusterSet.add(modelRela.getFromModelId());
|
||||
});
|
||||
return modelClusterSet;
|
||||
if (queryConfig.getTagTypeDefaultConfig() == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getTagTypeDefaultConfig().getTimeDefaultConfig();
|
||||
}
|
||||
|
||||
public TimeDefaultConfig getMetricTypeTimeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
if (queryConfig.getMetricTypeDefaultConfig() == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getMetricTypeDefaultConfig().getTimeDefaultConfig();
|
||||
}
|
||||
|
||||
public TagTypeDefaultConfig getTagTypeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getTagTypeDefaultConfig();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -13,26 +12,5 @@ public class ChatDefaultConfigReq {
|
||||
private List<Long> dimensionIds = new ArrayList<>();
|
||||
private List<Long> metricIds = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* default time span unit
|
||||
*/
|
||||
private Integer unit = 1;
|
||||
|
||||
/**
|
||||
* default time type: day
|
||||
* DAY, WEEK, MONTH, YEAR
|
||||
*/
|
||||
private String period = Constants.DAY;
|
||||
|
||||
private TimeMode timeMode = TimeMode.LAST;
|
||||
|
||||
public enum TimeMode {
|
||||
/**
|
||||
* date mode
|
||||
* LAST - a certain time
|
||||
* RECENT - a period time
|
||||
*/
|
||||
LAST, RECENT
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
import java.util.List;
|
||||
|
||||
import static java.time.LocalDate.now;
|
||||
|
||||
@ToString
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
public class DictLatestTaskReq {
|
||||
|
||||
@NotNull
|
||||
private Long modelId;
|
||||
|
||||
private List<Long> dimIds;
|
||||
|
||||
private String createdAt = now().plusDays(-4).toString();
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@ToString
|
||||
@Data
|
||||
public class DictTaskFilterReq {
|
||||
|
||||
private Long id;
|
||||
|
||||
private String name;
|
||||
|
||||
private String createdBy;
|
||||
|
||||
private String createdAt;
|
||||
|
||||
private TaskStatusEnum status;
|
||||
}
|
||||
@@ -13,7 +13,7 @@ public class PluginQueryReq {
|
||||
|
||||
private String type;
|
||||
|
||||
private String model;
|
||||
private String view;
|
||||
|
||||
private String pattern;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
||||
public class QueryReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId;
|
||||
private Long viewId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
|
||||
@@ -18,7 +18,7 @@ public class SimilarQueryReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private String modelId;
|
||||
private Long viewId;
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq.TimeMode;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -21,7 +21,7 @@ public class ChatDefaultRichConfigResp {
|
||||
private Integer unit = 1;
|
||||
|
||||
/**
|
||||
* default time type: day
|
||||
* default time type:
|
||||
* DAY, WEEK, MONTH, YEAR
|
||||
*/
|
||||
private String period = Constants.DAY;
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class DataInfo {
|
||||
|
||||
private Integer itemId;
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class EntityInfo {
|
||||
|
||||
private ModelInfo modelInfo = new ModelInfo();
|
||||
private ViewInfo viewInfo = new ViewInfo();
|
||||
private List<DataInfo> dimensions = new ArrayList<>();
|
||||
private List<DataInfo> metrics = new ArrayList<>();
|
||||
private String entityId;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.pojo.QueryAuthorization;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import java.util.Objects;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -6,7 +6,7 @@ import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ModelInfo extends DataInfo implements Serializable {
|
||||
public class ViewInfo extends DataInfo implements Serializable {
|
||||
|
||||
private List<String> words;
|
||||
private String primaryKey;
|
||||
@@ -21,70 +21,6 @@
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-context</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.hankcs</groupId>
|
||||
<artifactId>hanlp</artifactId>
|
||||
<version>${hanlp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-client</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.apache.zookeeper</groupId>
|
||||
<artifactId>zookeeper</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>javax.servlet</groupId>
|
||||
<artifactId>servlet-api</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-hdfs</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.apache.zookeeper</groupId>
|
||||
<artifactId>zookeeper</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>javax.servlet</groupId>
|
||||
<artifactId>servlet-api</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-common</artifactId>
|
||||
<version>${hadoop.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-log4j12</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>log4j</groupId>
|
||||
<artifactId>log4j</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.apache.zookeeper</groupId>
|
||||
<artifactId>zookeeper</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.apache.curator</groupId>
|
||||
<artifactId>*</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>javax.servlet</groupId>
|
||||
<artifactId>servlet-api</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
|
||||
@@ -4,6 +4,9 @@ package com.tencent.supersonic.chat.core.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -11,8 +14,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Data
|
||||
public class Agent extends RecordInfo {
|
||||
@@ -51,8 +52,8 @@ public class Agent extends RecordInfo {
|
||||
return enableSearch != null && enableSearch == 1;
|
||||
}
|
||||
|
||||
public static boolean containsAllModel(Set<Long> detectModelIds) {
|
||||
return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L);
|
||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
||||
}
|
||||
|
||||
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
|
||||
@@ -64,12 +65,12 @@ public class Agent extends RecordInfo {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(AgentToolType agentToolType) {
|
||||
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
|
||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
@@ -1,8 +1,25 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public enum AgentToolType {
|
||||
NL2SQL_RULE,
|
||||
NL2SQL_LLM,
|
||||
PLUGIN,
|
||||
ANALYTICS
|
||||
NL2SQL_RULE("基于规则Text-to-SQL"),
|
||||
NL2SQL_LLM("基于大模型Text-to-SQL"),
|
||||
PLUGIN("第三方插件");
|
||||
|
||||
private String title;
|
||||
|
||||
AgentToolType(String title) {
|
||||
this.title = title;
|
||||
}
|
||||
|
||||
public static Map<AgentToolType, String> getToolTypes() {
|
||||
Map<AgentToolType, String> map = new HashMap<>();
|
||||
map.put(NL2SQL_RULE, NL2SQL_RULE.title);
|
||||
map.put(NL2SQL_LLM, NL2SQL_LLM.title);
|
||||
map.put(PLUGIN, PLUGIN.title);
|
||||
return map;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class DataAnalyticsTool extends AgentTool {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
}
|
||||
@@ -1,16 +1,17 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class NL2SQLTool extends AgentTool {
|
||||
|
||||
protected List<Long> modelIds;
|
||||
protected List<Long> viewIds;
|
||||
|
||||
}
|
||||
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
|
||||
private List<String> queryTypes;
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
||||
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||
|
||||
import java.io.FileNotFoundException;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -11,7 +13,7 @@ import org.springframework.context.annotation.Configuration;
|
||||
@Data
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class LocalFileConfig {
|
||||
public class ChatLocalFileConfig {
|
||||
|
||||
|
||||
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.config;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ public class OptimizationConfig {
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${embedding.mapper.distance.threshold:0.58}")
|
||||
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||
private Double embeddingMapperDistanceThreshold;
|
||||
|
||||
@Value("${s2SQL.linking.value.switch:true}")
|
||||
@@ -73,9 +73,6 @@ public class OptimizationConfig {
|
||||
@Value("${text2sql.self.consistency.num:5}")
|
||||
private int text2sqlSelfConsistencyNum;
|
||||
|
||||
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
||||
private String text2sqlCollectionName;
|
||||
|
||||
@Value("${parse.show.count:3}")
|
||||
private Integer parseShowCount;
|
||||
|
||||
|
||||
@@ -1,14 +1,21 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -16,10 +23,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* basic semantic correction functionality, offering common methods and an
|
||||
@@ -42,7 +45,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Set<Long> modelIds) {
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
|
||||
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
@@ -52,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
// support fieldName and field alias
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> modelIds.contains(entry.getModel()))
|
||||
.filter(entry -> viewId.equals(entry.getView()))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
@@ -74,14 +77,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
|
||||
//decide whether add order by expression field to select
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||
}
|
||||
|
||||
// If there is no aggregate function in the S2SQL statement and
|
||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||
.collect(Collectors.toSet());
|
||||
@@ -93,16 +102,15 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
needAddFields.removeAll(selectFields);
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, modelIds);
|
||||
Long viewId = semanticParseInfo.getView().getView();
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
.map(schemaElement -> {
|
||||
@@ -123,13 +131,28 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Set<Long> modelIds) {
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
return semanticSchema.getMetrics(modelIds);
|
||||
return semanticSchema.getMetrics(viewId);
|
||||
}
|
||||
|
||||
protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
return dimensions;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
/**
|
||||
* Perform SQL corrections on the "From" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class FromCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String modelName = semanticParseInfo.getModel().getName();
|
||||
String correctSql = SqlParserReplaceHelper
|
||||
.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* Correcting SQL syntax, primarily including fixes to select, where, groupBy, and Having clauses
|
||||
*/
|
||||
@Slf4j
|
||||
public class GrammarCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private List<BaseSemanticCorrector> correctors;
|
||||
|
||||
public GrammarCorrector() {
|
||||
correctors = new ArrayList<>();
|
||||
correctors.add(new SelectCorrector());
|
||||
correctors.add(new WhereCorrector());
|
||||
correctors.add(new GroupByCorrector());
|
||||
correctors.add(new HavingCorrector());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
for (BaseSemanticCorrector corrector : correctors) {
|
||||
corrector.correct(queryContext, semanticParseInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,21 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.HashSet;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -22,47 +29,67 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo);
|
||||
if (!needAddGroupBy) {
|
||||
return;
|
||||
}
|
||||
addGroupByFields(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
ViewService viewService = ContextUtils.getBean(ViewService.class);
|
||||
ModelService modelService = ContextUtils.getBean(ModelService.class);
|
||||
ViewResp viewResp = viewService.getView(viewId);
|
||||
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
|
||||
.collect(Collectors.toList());
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setIds(modelIds);
|
||||
List<ModelResp> modelRespList = modelService.getModelList(metaFilter);
|
||||
for (ModelResp modelResp : modelRespList) {
|
||||
List<Dim> dimList = modelResp.getModelDetail().getDimensions();
|
||||
for (Dim dim : dimList) {
|
||||
if (Objects.nonNull(dim.getTypeParams()) && dim.getTypeParams().getTimeGranularity().equals("none")) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
// check has distinct
|
||||
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
||||
log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
|
||||
return false;
|
||||
}
|
||||
//add alias field name
|
||||
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return false;
|
||||
}
|
||||
// if only date in select not add group by.
|
||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
return false;
|
||||
}
|
||||
if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
// if only date in select not add group by.
|
||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
return;
|
||||
}
|
||||
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(field -> {
|
||||
@@ -72,13 +99,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
addAggregate(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
|
||||
@@ -1,17 +1,21 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||
@@ -25,34 +29,38 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
//add aggregate to all metric
|
||||
addHaving(queryContext, semanticParseInfo);
|
||||
|
||||
//add having expression filed to select
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
//decide whether add having expression field to select
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo = environment.getProperty("corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
Long viewId = semanticParseInfo.getView().getView();
|
||||
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
|
||||
@@ -1,24 +1,33 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2QL.
|
||||
* Perform schema corrections on the Schema information in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
@@ -26,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||
|
||||
correctAggFunction(semanticParseInfo);
|
||||
|
||||
replaceAlias(semanticParseInfo);
|
||||
@@ -40,20 +51,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getModel().getModelIds());
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -69,7 +80,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -101,7 +112,38 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
||||
return;
|
||||
}
|
||||
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema);
|
||||
|
||||
if (CollectionUtils.isEmpty(linkingValues)) {
|
||||
linkingValues = new ArrayList<>();
|
||||
}
|
||||
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> removeFieldNames = whereExpressionList.stream()
|
||||
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
|
||||
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
|
||||
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
|
||||
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
|
||||
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
|
||||
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
|
||||
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
||||
|
||||
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -16,8 +16,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& !CollectionUtils.isEmpty(selectFields)
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the time in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class TimeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addLowerBoundDate(semanticParseInfo);
|
||||
|
||||
}
|
||||
|
||||
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||
if (Objects.isNull(dateBoundInfo)) {
|
||||
return;
|
||||
}
|
||||
if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
|
||||
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
|
||||
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
|
||||
String upperDate = dateBoundInfo.getUpperDate();
|
||||
try {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,32 +1,33 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
*/
|
||||
@@ -38,8 +39,6 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
||||
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
addQueryFilter(queryContext, semanticParseInfo);
|
||||
|
||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||
@@ -58,26 +57,29 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
||||
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
|
||||
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
||||
&& StringUtils.isNotBlank(startEndDate.getRight())) {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
String dateChName = TimeDimensionEnum.DAY.getChName();
|
||||
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
|
||||
startEndDate.getLeft(), dateChName, startEndDate.getRight());
|
||||
try {
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression:{}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
@@ -99,15 +101,15 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
|
||||
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class DictConfig {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private List<DimValueInfo> dimValueInfoList;
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DimValue2DictCommand {
|
||||
|
||||
private DictUpdateMode updateMode;
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
private Map<Long, List<Long>> modelAndDimPair = new HashMap<>();
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.Set;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DimValueDictInfo {
|
||||
|
||||
private Long id;
|
||||
|
||||
private String name;
|
||||
|
||||
private String description;
|
||||
|
||||
private String command;
|
||||
|
||||
private TaskStatusEnum status;
|
||||
|
||||
private String createdBy;
|
||||
|
||||
private Date createdAt;
|
||||
|
||||
private Long elapsedMs;
|
||||
|
||||
private Set<Long> dimIds;
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import java.util.List;
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
public class DimValueInfo {
|
||||
|
||||
/**
|
||||
* metricId、DimensionId、domainId
|
||||
*/
|
||||
private Long itemId;
|
||||
|
||||
/**
|
||||
* type: IntentionTypeEnum
|
||||
* temporarily only supports dimension-related information
|
||||
*/
|
||||
@NotNull
|
||||
private TypeEnums type = TypeEnums.DIMENSION;
|
||||
|
||||
private List<String> blackList;
|
||||
private List<String> whiteList;
|
||||
private List<String> ruleList;
|
||||
private Boolean isDictInfo;
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.io.Serializable;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@Builder
|
||||
public class ModelInfoStat implements Serializable {
|
||||
|
||||
private long modelCount;
|
||||
|
||||
private long metricModelCount;
|
||||
|
||||
private long dimensionModelCount;
|
||||
|
||||
private long dimensionValueModelCount;
|
||||
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
|
||||
|
||||
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
|
||||
|
||||
@SneakyThrows
|
||||
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
if (cacheEnable) {
|
||||
return modelSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
||||
modelSchemaCache.put(String.valueOf(ids), data);
|
||||
return data;
|
||||
});
|
||||
}
|
||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
||||
return data;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelSchema getModelSchema(Long model, Boolean cacheEnable) {
|
||||
List<Long> ids = new ArrayList<>();
|
||||
ids.add(model);
|
||||
List<ModelSchemaResp> modelSchemaResps = fetchModelSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||
Optional<ModelSchemaResp> modelSchemaResp = modelSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(model)).findFirst();
|
||||
if (modelSchemaResp.isPresent()) {
|
||||
ModelSchemaResp modelSchema = modelSchemaResp.get();
|
||||
return ModelSchemaBuilder.build(modelSchema);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchema> getModelSchema() {
|
||||
return getModelSchema(new ArrayList<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchema> getModelSchema(List<Long> ids) {
|
||||
List<ModelSchema> domainSchemaList = new ArrayList<>();
|
||||
|
||||
for (ModelSchemaResp resp : fetchModelSchema(ids, true)) {
|
||||
domainSchemaList.add(ModelSchemaBuilder.build(resp));
|
||||
}
|
||||
|
||||
return domainSchemaList;
|
||||
}
|
||||
|
||||
protected abstract List<ModelSchemaResp> doFetchModelSchema(List<Long> ids);
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.headless.api.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A semantic layer provides a simplified and consistent view of data from multiple sources.
|
||||
* It abstracts away the complexity of the underlying data sources and provides a unified view
|
||||
* of the data that is easier to understand and use.
|
||||
* <p>
|
||||
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
|
||||
* Implementations of this interface should provide concrete implementations that interact with the
|
||||
* underlying data sources and return results in a consistent format. Or it can be implemented
|
||||
* as proxy to a remote semantic service.
|
||||
* </p>
|
||||
*/
|
||||
public interface SemanticInterpreter {
|
||||
|
||||
SemanticQueryResp queryByStruct(QueryStructReq queryStructReq, User user);
|
||||
|
||||
SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||
|
||||
SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user);
|
||||
|
||||
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||
|
||||
List<ModelSchema> getModelSchema();
|
||||
|
||||
List<ModelSchema> getModelSchema(List<Long> ids);
|
||||
|
||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
||||
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||
|
||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
||||
|
||||
List<DomainResp> getDomainList(User user);
|
||||
|
||||
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
||||
|
||||
}
|
||||
@@ -1,21 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseMapper implements SchemaMapper {
|
||||
@@ -25,7 +26,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
|
||||
|
||||
try {
|
||||
doMap(queryContext);
|
||||
@@ -34,13 +35,13 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
|
||||
}
|
||||
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = modelElementMatches.get(modelId);
|
||||
@@ -66,14 +67,14 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID,
|
||||
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
||||
if (Objects.isNull(modelSchema)) {
|
||||
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||
if (Objects.isNull(viewSchema)) {
|
||||
return null;
|
||||
}
|
||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
return null;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -27,22 +27,23 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
|
||||
|
||||
List<T> detects = detect(queryContext, terms, detectModelIds);
|
||||
List<T> detects = detect(queryContext, terms, detectViewIds);
|
||||
Map<MatchText, List<T>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryContext.getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
@@ -55,25 +56,26 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
String detectSegment = text.substring(startIndex, index).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
||||
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
||||
detectByBatch(queryContext, results, detectViewIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
|
||||
Set<String> detectSegments) {
|
||||
return;
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
||||
return terms.stream().sorted(Comparator.comparing(Term::length))
|
||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||
return terms.stream().sorted(Comparator.comparing(S2Term::length))
|
||||
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
|
||||
(value1, value2) -> value2));
|
||||
}
|
||||
|
||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||
@@ -101,10 +103,10 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
||||
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||
terms = filterByViewId(terms, viewIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
@@ -119,27 +121,27 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
if (CollectionUtils.isNotEmpty(viewIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return detectModelIds.contains(modelId);
|
||||
Long viewId = NatureHelper.getViewId(term.getNature().toString());
|
||||
if (Objects.nonNull(viewId)) {
|
||||
return viewIds.contains(viewId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||
log.info("terms filter by viewId:{}", viewIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
|
||||
public void logTerms(List<Term> terms) {
|
||||
public void logTerms(List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return;
|
||||
}
|
||||
for (Term term : terms) {
|
||||
for (S2Term term : terms) {
|
||||
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
}
|
||||
}
|
||||
@@ -148,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -14,11 +20,6 @@ import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||
@@ -35,10 +36,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
this.allElements = getSchemaElements(queryContext);
|
||||
return super.match(queryContext, terms, detectModelIds);
|
||||
return super.match(queryContext, terms, detectViewIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -53,16 +54,13 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
@@ -72,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
@@ -98,7 +96,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with vector embedding.
|
||||
@@ -24,7 +25,8 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
//1. query from embedding by queryText
|
||||
String queryText = queryContext.getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||
@@ -34,16 +36,12 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
//2. build SchemaElementMatch by info
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||
SchemaElement.class);
|
||||
|
||||
String modelIdStr = matchResult.getMetadata().get("modelId");
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||
if (Objects.isNull(viewId)) {
|
||||
continue;
|
||||
}
|
||||
long modelId = Long.parseLong(modelIdStr);
|
||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId,
|
||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||
queryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
@@ -56,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
//3. add to mapInfo
|
||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,17 +2,15 @@ package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -36,9 +34,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
|
||||
@Override
|
||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||
@@ -52,7 +48,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
Set<String> detectSegments) {
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
@@ -66,49 +68,29 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
||||
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
List<String> queryTextsSub) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
Map<String, String> filterCondition = null;
|
||||
// step1. build query params
|
||||
// if only one modelId, add to filterCondition
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||
filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
||||
}
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(queryTextsSub)
|
||||
.filterCondition(filterCondition)
|
||||
.queryEmbeddings(null)
|
||||
.build();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||
embeddingConfig.getMetaCollectionName(), retrieveQuery, embeddingNumber);
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. build EmbeddingResults. filter by modelId
|
||||
// step3. build EmbeddingResults
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
return true;
|
||||
}
|
||||
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||
});
|
||||
}
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
@@ -119,6 +101,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
|
||||
embeddingResult.setMetadata(convertedMap);
|
||||
return embeddingResult;
|
||||
}))
|
||||
.collect(Collectors.toList());
|
||||
@@ -132,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||
*/
|
||||
@@ -23,12 +24,12 @@ public class EntityMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
||||
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(modelId, queryContext);
|
||||
SchemaElement entity = getEntity(viewId, queryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -64,9 +65,9 @@ public class EntityMapper extends BaseMapper {
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long modelId, QueryContext queryContext) {
|
||||
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
||||
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
@@ -34,17 +34,20 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
|
||||
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
|
||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
@@ -57,20 +60,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String text = queryContext.getQueryText();
|
||||
Integer agentId = queryContext.getAgentId();
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, agentId, detectModelIds).stream()
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
|
||||
@@ -1,24 +1,26 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with keyword.
|
||||
@@ -31,7 +33,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
//1.hanlpDict Match
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
KnowledgeService knowledgeService = ContextUtils.getBean(KnowledgeService.class);
|
||||
List<S2Term> terms = knowledgeService.getTerms(queryText);
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||
@@ -45,7 +48,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||
List<Term> terms) {
|
||||
List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
@@ -56,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
if (Objects.isNull(modelId)) {
|
||||
Long viewId = NatureHelper.getViewId(nature);
|
||||
if (Objects.isNull(viewId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||
@@ -65,8 +68,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID,
|
||||
queryContext.getSemanticSchema());
|
||||
SchemaElement element = getSchemaElement(viewId, elementType,
|
||||
elementID, queryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -82,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,12 +106,12 @@ public class KeywordMapper extends BaseMapper {
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -12,10 +17,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Data
|
||||
@Service
|
||||
@@ -35,8 +36,8 @@ public class MapperHelper {
|
||||
return index;
|
||||
}
|
||||
|
||||
public Integer getStepOffset(List<Term> termList, Integer index) {
|
||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||
public Integer getStepOffset(List<S2Term> termList, Integer index) {
|
||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
|
||||
for (int j = 0; j < termList.size() - 1; j++) {
|
||||
@@ -61,7 +62,7 @@ public class MapperHelper {
|
||||
*/
|
||||
public boolean existDimensionValues(List<String> natures) {
|
||||
for (String nature : natures) {
|
||||
if (NatureHelper.isDimensionValueModelId(nature)) {
|
||||
if (NatureHelper.isDimensionValueViewId(nature)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -81,33 +82,33 @@ public class MapperHelper {
|
||||
detectSegment.length());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(Long modelId, Agent agent) {
|
||||
public Set<Long> getViewIds(Long viewId, Agent agent) {
|
||||
|
||||
Set<Long> detectModelIds = new HashSet<>();
|
||||
Set<Long> detectViewIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
detectModelIds = agent.getModelIds(null);
|
||||
detectViewIds = agent.getViewIds(null);
|
||||
}
|
||||
//contains all
|
||||
if (Agent.containsAllModel(detectModelIds)) {
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (Agent.containsAllModel(detectViewIds)) {
|
||||
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(modelId);
|
||||
result.add(viewId);
|
||||
return result;
|
||||
}
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
if (Objects.nonNull(detectModelIds)) {
|
||||
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||
if (Objects.nonNull(detectViewIds)) {
|
||||
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
|
||||
if (detectModelIds.contains(modelId)) {
|
||||
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
|
||||
if (detectViewIds.contains(viewId)) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(modelId);
|
||||
result.add(viewId);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return detectModelIds;
|
||||
return detectViewIds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -12,6 +13,6 @@ import java.util.Set;
|
||||
*/
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
|
||||
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* ModelClusterMapper build a cluster from
|
||||
* connectable data models based on model-rela configuration
|
||||
* and generate SchemaModelClusterMapInfo
|
||||
*/
|
||||
public class ModelClusterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
|
||||
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
|
||||
for (ModelCluster modelCluster : modelClusters) {
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
if (modelCluster.getModelIds().contains(modelId)) {
|
||||
modelClusterElementMatches.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
|
||||
.addAll(schemaMapInfo.getMatchedElements(modelId));
|
||||
}
|
||||
}
|
||||
}
|
||||
SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
modelClusterMapInfo.setModelElementMatches(modelClusterElementMatches);
|
||||
queryContext.setModelClusterMapInfo(modelClusterMapInfo);
|
||||
}
|
||||
|
||||
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
|
||||
SemanticSchema semanticSchema) {
|
||||
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
|
||||
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
|
||||
modelCluster.removeIf(model -> !matchedModels.contains(model));
|
||||
}).filter(modelCluster -> modelCluster.size() > 0).map(ModelCluster::build).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import java.io.Serializable;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterMapper implements SchemaMapper {
|
||||
@@ -23,22 +24,22 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
Long modelId = queryContext.getModelId();
|
||||
if (modelId == null || modelId <= 0) {
|
||||
Long viewId = queryContext.getViewId();
|
||||
if (viewId == null || viewId <= 0) {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(modelId, schemaMapInfo);
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
|
||||
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
|
||||
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||
}
|
||||
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getModelElementMatches().entrySet()) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||
if (!entry.getKey().equals(modelId)) {
|
||||
entry.getValue().clear();
|
||||
}
|
||||
@@ -60,7 +61,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.model(queryContext.getModelId())
|
||||
.view(queryContext.getViewId())
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.server.service.KnowledgeService;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -14,6 +15,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
@@ -25,9 +27,12 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
private static final int SEARCH_SIZE = 3;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeService knowledgeService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
||||
Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -51,10 +56,10 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
String detectSegment = text.substring(detectIndex);
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, detectViewIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, detectViewIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
@@ -88,9 +93,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||
Integer startIndex,
|
||||
Integer i, int offset) {
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@@ -37,12 +38,12 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
String modelName = llmReq.getSchema().getViewName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, viewId);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
|
||||
@@ -15,7 +15,7 @@ public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||
LLMResp query2sql(LLMReq llmReq, Long viewId);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -28,6 +25,10 @@ import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* PythonLLMProxy sends requests to LangChain-based python service.
|
||||
*/
|
||||
@@ -47,10 +48,10 @@ public class PythonLLMProxy implements LLMProxy {
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
@@ -12,14 +12,15 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||
@@ -49,21 +50,21 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.ID;
|
||||
}
|
||||
//1. entity queryType
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
Long viewId = parseInfo.getViewId();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
|
||||
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||
return QueryType.TAG;
|
||||
@@ -71,8 +72,8 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.plugin;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
@@ -18,7 +18,6 @@ import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.HashMap;
|
||||
@@ -56,13 +55,13 @@ public abstract class PluginParser implements SemanticParser {
|
||||
|
||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||
Plugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> modelIds = pluginRecallResult.getModelIds();
|
||||
Set<Long> viewIds = pluginRecallResult.getViewIds();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
modelIds = Sets.newHashSet(-1L);
|
||||
viewIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long modelId : modelIds) {
|
||||
for (Long viewId : viewIds) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
|
||||
queryContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
@@ -75,20 +74,19 @@ public abstract class PluginParser implements SemanticParser {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin,
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
|
||||
QueryContext queryContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
queryContext.getModelClusterMapInfo().getMatchedElements(modelId);
|
||||
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
||||
modelId = plugin.getModelList().get(0);
|
||||
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
|
||||
viewId = plugin.getViewList().get(0);
|
||||
}
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
|
||||
@@ -57,15 +57,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> modelList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
Set<Long> viewList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
continue;
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
|
||||
@@ -12,15 +12,16 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
||||
*/
|
||||
@@ -56,19 +57,19 @@ public class FunctionCallParser extends PluginParser {
|
||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||
if (pluginResolveResult.getLeft()) {
|
||||
Set<Long> modelList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
Set<Long> viewList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
return null;
|
||||
}
|
||||
double score = queryContext.getQueryText().length();
|
||||
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
|
||||
return PluginRecallResult.builder().plugin(plugin).viewIds(viewList).score(score).build();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public FunctionResp functionCall(QueryContext queryContext) {
|
||||
List<PluginParseConfig> pluginToFunctionCall =
|
||||
getPluginToFunctionCall(queryContext.getModelId(), queryContext);
|
||||
getPluginToFunctionCall(queryContext.getViewId(), queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return null;
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
protected static String selectModelBySchemaElementMatchScore(Map<String, SemanticQuery> modelQueryModes,
|
||||
SchemaModelClusterMapInfo schemaMap) {
|
||||
//model count priority
|
||||
String modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
||||
if (Objects.nonNull(modelIdByModelCount)) {
|
||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||
return modelIdByModelCount;
|
||||
}
|
||||
|
||||
Map<String, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||
if (modelTypeMap.size() == 1) {
|
||||
String modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||
if (modelQueryModes.containsKey(modelSelect)) {
|
||||
log.info("selectModel with only one Model [{}]", modelSelect);
|
||||
return modelSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<String, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxModel != null) {
|
||||
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
||||
return maxModel.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static String getModelIdByMatchModelScore(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<String, Double> modelIdToModelScore = new HashMap<>();
|
||||
if (Objects.nonNull(modelElementMatches)) {
|
||||
for (Entry<String, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||
String modelId = modelElementMatch.getKey();
|
||||
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
||||
// get sum of model match score
|
||||
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
modelIdToModelScore.put(modelId, score);
|
||||
}
|
||||
}
|
||||
Entry<String, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
||||
if (Objects.nonNull(maxModelScore)) {
|
||||
return maxModelScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, ModelMatchResult> modelCount = new HashMap<>();
|
||||
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!modelCount.containsKey(entry.getKey())) {
|
||||
modelCount.put(entry.getKey(), new ModelMatchResult());
|
||||
}
|
||||
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
modelMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return modelCount;
|
||||
}
|
||||
|
||||
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
||||
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
||||
Long modelId = queryContext.getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
|
||||
return getModelClusterByModelId(modelId, matchedModelClusters);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
Map<String, SemanticQuery> modelQueryModes = new HashMap<>();
|
||||
for (String matchedModel : matchedModelClusters) {
|
||||
modelQueryModes.put(matchedModel, null);
|
||||
}
|
||||
if (modelQueryModes.size() == 1) {
|
||||
return modelQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectModelBySchemaElementMatchScore(modelQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
private String getModelClusterByModelId(Long modelId, Set<String> modelClusterKeySet) {
|
||||
for (String modelClusterKey : modelClusterKeySet) {
|
||||
if (ModelCluster.build(modelClusterKey).getModelIds().contains(modelId)) {
|
||||
return modelClusterKey;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicViewResolver implements ViewResolver {
|
||||
|
||||
protected static Long selectViewBySchemaElementMatchScore(Map<Long, SemanticQuery> viewQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//view count priority
|
||||
Long viewIdByViewCount = getViewIdByMatchViewScore(schemaMap);
|
||||
if (Objects.nonNull(viewIdByViewCount)) {
|
||||
log.info("selectView by view count:{}", viewIdByViewCount);
|
||||
return viewIdByViewCount;
|
||||
}
|
||||
|
||||
Map<Long, ViewMatchResult> viewTypeMap = getViewTypeMap(schemaMap);
|
||||
if (viewTypeMap.size() == 1) {
|
||||
Long viewSelect = new ArrayList<>(viewTypeMap.entrySet()).get(0).getKey();
|
||||
if (viewQueryModes.containsKey(viewSelect)) {
|
||||
log.info("selectView with only one View [{}]", viewSelect);
|
||||
return viewSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<Long, ViewMatchResult> maxView = viewTypeMap.entrySet().stream()
|
||||
.filter(entry -> viewQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxView != null) {
|
||||
log.info("selectView with multiple Views [{}]", maxView.getKey());
|
||||
return maxView.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getViewIdByMatchViewScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> viewElementMatches = schemaMap.getViewElementMatches();
|
||||
// calculate view match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, Double> viewIdToViewScore = new HashMap<>();
|
||||
if (Objects.nonNull(viewElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> viewElementMatch : viewElementMatches.entrySet()) {
|
||||
Long viewId = viewElementMatch.getKey();
|
||||
List<Double> viewMatchesScore = viewElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.VIEW.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(viewMatchesScore)) {
|
||||
// get sum of view match score
|
||||
double score = viewMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
viewIdToViewScore.put(viewId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxViewScore = viewIdToViewScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||
log.info("maxViewCount:{},viewIdToViewCount:{}", maxViewScore, viewIdToViewScore);
|
||||
if (Objects.nonNull(maxViewScore)) {
|
||||
return maxViewScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<Long, ViewMatchResult> getViewTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, ViewMatchResult> viewCount = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getViewElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!viewCount.containsKey(entry.getKey())) {
|
||||
viewCount.put(entry.getKey(), new ViewMatchResult());
|
||||
}
|
||||
ViewMatchResult viewMatchResult = viewCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
viewMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
viewMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return viewCount;
|
||||
}
|
||||
|
||||
public Long resolve(QueryContext queryContext, Set<Long> agentViewIds) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
Set<Long> matchedViews = mapInfo.getMatchedViewInfos();
|
||||
Long viewId = queryContext.getViewId();
|
||||
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||
if (CollectionUtils.isEmpty(agentViewIds) || agentViewIds.contains(viewId)) {
|
||||
return viewId;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(agentViewIds)) {
|
||||
matchedViews.retainAll(agentViewIds);
|
||||
}
|
||||
Map<Long, SemanticQuery> viewQueryModes = new HashMap<>();
|
||||
for (Long viewIds : matchedViews) {
|
||||
viewQueryModes.put(viewIds, null);
|
||||
}
|
||||
if (viewQueryModes.size() == 1) {
|
||||
return viewQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectViewBySchemaElementMatchScore(viewQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,31 +1,35 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
@@ -35,12 +39,6 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@@ -63,79 +61,54 @@ public class LLMRequestService {
|
||||
return false;
|
||||
}
|
||||
|
||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
public Long getViewId(QueryContext queryCtx) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
Set<Long> distinctModelIds = new HashSet<>();
|
||||
Set<Long> agentViewIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);
|
||||
agentViewIds = agent.getViewIds(AgentToolType.NL2SQL_LLM);
|
||||
}
|
||||
if (llmParserConfig.getAllModel()) {
|
||||
ModelCluster modelCluster = ModelCluster.build(distinctModelIds);
|
||||
if (!CollectionUtils.isEmpty(queryCtx.getCandidateQueries())) {
|
||||
queryCtx.getCandidateQueries().stream().forEach(o -> {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(o.getParseInfo().getQueryMode())) {
|
||||
o.getParseInfo().setModel(modelCluster);
|
||||
}
|
||||
});
|
||||
}
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE);
|
||||
semanticQuery.getParseInfo().setModel(modelCluster);
|
||||
List<SchemaElementMatch> schemaElementMatches = new ArrayList<>();
|
||||
distinctModelIds.stream().forEach(o -> {
|
||||
if (!CollectionUtils.isEmpty(queryCtx.getMapInfo().getMatchedElements(o))) {
|
||||
schemaElementMatches.addAll(queryCtx.getMapInfo().getMatchedElements(o));
|
||||
}
|
||||
});
|
||||
queryCtx.getModelClusterMapInfo().setMatchedElements(modelCluster.getKey(), schemaElementMatches);
|
||||
return modelCluster;
|
||||
if (Agent.containsAllModel(agentViewIds)) {
|
||||
agentViewIds = new HashSet<>();
|
||||
}
|
||||
if (Agent.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
String modelCluster = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||
log.info("resolve modelId:{},llmParser Models:{}", modelCluster, distinctModelIds);
|
||||
return ModelCluster.build(modelCluster);
|
||||
ViewResolver viewResolver = ComponentFactory.getModelResolver();
|
||||
return viewResolver.resolve(queryCtx, agentViewIds);
|
||||
}
|
||||
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Set<Long> modelIdSet) {
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Long viewId) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
if (Objects.isNull(agent)) {
|
||||
return null;
|
||||
}
|
||||
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
|
||||
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
if (Agent.containsAllModel(new HashSet<>(modelIds))) {
|
||||
List<Long> viewIds = tool.getViewIds();
|
||||
if (Agent.containsAllModel(new HashSet<>(viewIds))) {
|
||||
return true;
|
||||
}
|
||||
for (Long modelId : modelIdSet) {
|
||||
if (modelIds.contains(modelId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return viewIds.contains(viewId);
|
||||
})
|
||||
.findFirst();
|
||||
return llmParserTool.orElse(null);
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long viewId,
|
||||
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> viewIdToName = semanticSchema.getViewIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
Long firstModelId = modelCluster.getFirstModel();
|
||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setModelName(modelIdToName.get(firstModelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(firstModelId));
|
||||
llmSchema.setViewName(viewIdToName.get(viewId));
|
||||
llmSchema.setDomainName(viewIdToName.get(viewId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelCluster, llmParserConfig);
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, viewId, llmParserConfig);
|
||||
|
||||
String priorExts = getPriorExts(modelCluster.getModelIds(), fieldNameList);
|
||||
String priorExts = getPriorExts(viewId, fieldNameList);
|
||||
llmReq.setPriorExts(priorExts);
|
||||
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
@@ -148,7 +121,7 @@ public class LLMRequestService {
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, firstModelId);
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, viewId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
@@ -157,29 +130,28 @@ public class LLMRequestService {
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, modelClusterKey);
|
||||
public LLMResp requestLLM(LLMReq llmReq, Long viewId) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, viewId);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long viewId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, modelCluster, llmParserConfig);
|
||||
Set<String> results = getTopNFieldNames(queryCtx, viewId, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, viewId);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
private String getPriorExts(Set<Long> modelIds, List<String> fieldNameList) {
|
||||
private String getPriorExts(Long viewId, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
List<ModelSchemaResp> modelSchemaResps = semanticInterpreter.fetchModelSchema(
|
||||
new ArrayList<>(modelIds), true);
|
||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||
|
||||
ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = modelSchemaResp.getMetrics()
|
||||
List<ViewSchemaResp> viewSchemaResps = semanticInterpreter.fetchViewSchema(
|
||||
Lists.newArrayList(viewId), true);
|
||||
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
|
||||
ViewSchemaResp viewSchemaResp = viewSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = viewSchemaResp.getMetrics()
|
||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||
.flatMap(metricSchemaResp -> {
|
||||
Set<Pair<String, String>> result = new HashSet<>();
|
||||
@@ -207,11 +179,9 @@ public class LLMRequestService {
|
||||
return extraInfoSb.toString();
|
||||
}
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long viewId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
@@ -231,22 +201,21 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long viewId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
return semanticSchema.getDimensions(viewId).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long viewId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
Set<String> results = semanticSchema.getDimensions(viewId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelCluster.getModelIds()).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
@@ -256,10 +225,9 @@ public class LLMRequestService {
|
||||
return results;
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long viewId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -28,10 +28,9 @@ public class LLMResponseService {
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
parseInfo.setView(queryCtx.getSemanticSchema().getView(parseResult.getViewId()));
|
||||
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(parseInfo.getModelClusterKey()));
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getViewId()));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
@@ -42,7 +41,6 @@ public class LLMResponseService {
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
@@ -54,7 +52,7 @@ public class LLMResponseService {
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
|
||||
@@ -9,14 +9,13 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
@@ -30,31 +29,30 @@ public class LLMSqlParser implements SemanticParser {
|
||||
}
|
||||
try {
|
||||
//2.get modelId from queryCtx and chatCtx.
|
||||
ModelCluster modelCluster = requestService.getModelCluster(queryCtx, chatCtx);
|
||||
if (StringUtils.isBlank(modelCluster.getKey())) {
|
||||
Long viewId = requestService.getViewId(queryCtx);
|
||||
if (viewId == null) {
|
||||
return;
|
||||
}
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, modelCluster.getModelIds());
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, viewId);
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelCluster);
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, viewId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, semanticSchema, modelCluster, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, modelCluster.getKey());
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, viewId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, viewId);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
//5. deduplicate the SQL result list and build parserInfo
|
||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.modelCluster(modelCluster)
|
||||
.viewId(viewId)
|
||||
.commonAgentTool(commonAgentTool)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import java.util.Set;
|
||||
|
||||
public interface ModelResolver {
|
||||
|
||||
String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -11,11 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
@@ -24,6 +19,12 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
@@ -33,7 +34,7 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@@ -42,12 +43,12 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
@@ -2,19 +2,16 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -22,6 +19,10 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
@@ -31,7 +32,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
private SqlExamplarLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@@ -40,11 +41,11 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -19,7 +18,7 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class ParseResult {
|
||||
|
||||
private ModelCluster modelCluster;
|
||||
private Long viewId;
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import java.util.Objects;
|
||||
|
||||
public class S2SqlDateHelper {
|
||||
|
||||
public static String getReferenceDate(QueryContext queryContext, Long modelId) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
if (Objects.isNull(modelId)) {
|
||||
return defaultDate;
|
||||
}
|
||||
ChatConfigFilter filter = new ChatConfigFilter();
|
||||
filter.setModelId(modelId);
|
||||
ChatConfigRichResp chatConfigRichResp = queryContext.getModelIdToChatRichConfig().get(modelId);
|
||||
|
||||
if (Objects.isNull(chatConfigRichResp)) {
|
||||
return defaultDate;
|
||||
}
|
||||
if (Objects.isNull(chatConfigRichResp.getChatDetailRichConfig()) || Objects.isNull(
|
||||
chatConfigRichResp.getChatDetailRichConfig().getChatDefaultConfig())) {
|
||||
return defaultDate;
|
||||
}
|
||||
|
||||
ChatDefaultRichConfigResp chatDefaultConfig = chatConfigRichResp.getChatDetailRichConfig()
|
||||
.getChatDefaultConfig();
|
||||
Integer unit = chatDefaultConfig.getUnit();
|
||||
String period = chatDefaultConfig.getPeriod();
|
||||
if (Objects.nonNull(unit)) {
|
||||
// If the unit is set to less than 0, then do not add relative date.
|
||||
if (unit < 0) {
|
||||
return null;
|
||||
}
|
||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
|
||||
if (Objects.isNull(datePeriodEnum)) {
|
||||
return DateUtils.getBeforeDate(unit);
|
||||
} else {
|
||||
return DateUtils.getBeforeDate(unit, datePeriodEnum);
|
||||
}
|
||||
}
|
||||
return defaultDate;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
@@ -19,12 +20,13 @@ import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SqlExampleLoader {
|
||||
public class SqlExamplarLoader {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json";
|
||||
|
||||
@@ -32,6 +34,9 @@ public class SqlExampleLoader {
|
||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||
};
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<SqlExample> getSqlExamples() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
@@ -53,8 +58,8 @@ public class SqlExampleLoader {
|
||||
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||
}
|
||||
|
||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
|
||||
|
||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||
.queryEmbeddings(null).build();
|
||||
|
||||
@@ -12,9 +12,9 @@ public interface SqlGeneration {
|
||||
/***
|
||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
||||
* @param llmReq
|
||||
* @param modelClusterKey
|
||||
* @param viewId
|
||||
* @return
|
||||
*/
|
||||
LLMResp generation(LLMReq llmReq, String modelClusterKey);
|
||||
LLMResp generation(LLMReq llmReq, Long viewId);
|
||||
|
||||
}
|
||||
|
||||
@@ -2,14 +2,15 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -95,7 +96,7 @@ public class SqlPromptGenerator {
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
String modelName = llmReq.getSchema().getViewName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
|
||||
@@ -11,10 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -22,6 +18,11 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@@ -30,7 +31,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@@ -39,11 +40,11 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
|
||||
@@ -11,9 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -21,6 +18,10 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
@@ -30,7 +31,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@@ -39,10 +40,10 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelMatchResult {
|
||||
public class ViewMatchResult {
|
||||
private Integer count = 0;
|
||||
private double maxSimilarity;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public interface ViewResolver {
|
||||
|
||||
/**
 * Resolves a single view id for the given query context.
 *
 * <p>NOTE(review): presumably the returned id is the view the query should be parsed
 * against, and {@code restrictiveModels} limits the candidate ids - confirm against the
 * implementations. The parameter name still says "models" although the method resolves
 * a view; verify whether it carries model ids or view ids.
 *
 * @param queryContext      the context of the query being parsed
 * @param restrictiveModels a set of ids passed in by the caller (see note above)
 * @return a view id; whether {@code null} may be returned is not visible from this
 *         interface - TODO confirm
 */
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -11,32 +11,34 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class AgentCheckParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
List<SemanticQuery> queries = queryContext.getCandidateQueries();
|
||||
agentCanSupport(queryContext, queries);
|
||||
log.info("query size before agent filter:{}", queryContext.getCandidateQueries().size());
|
||||
filterQueries(queryContext, queries);
|
||||
log.info("query size after agent filter: {}", queryContext.getCandidateQueries().size());
|
||||
}
|
||||
|
||||
private void agentCanSupport(QueryContext queryContext, List<SemanticQuery> queries) {
|
||||
private void filterQueries(QueryContext queryContext, List<SemanticQuery> queries) {
|
||||
Agent agent = queryContext.getAgent();
|
||||
if (agent == null) {
|
||||
return;
|
||||
}
|
||||
List<RuleParserTool> queryTools = getRuleTools(agent);
|
||||
if (CollectionUtils.isEmpty(queryTools)) {
|
||||
queries.clear();
|
||||
queryContext.setCandidateQueries(Lists.newArrayList());
|
||||
return;
|
||||
}
|
||||
log.info("queries resolved:{} {}", agent.getName(),
|
||||
log.info("agent name :{}, queries resolved: {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
queries.removeIf(query -> {
|
||||
for (RuleParserTool tool : queryTools) {
|
||||
@@ -46,26 +48,28 @@ public class AgentCheckParser implements SemanticParser {
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
|
||||
if (QueryManager.isTagQuery(query.getQueryMode())) {
|
||||
return !tool.getQueryTypes().contains(QueryType.TAG.name());
|
||||
if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (QueryManager.isMetricQuery(query.getQueryMode())) {
|
||||
return !tool.getQueryTypes().contains(QueryType.METRIC.name());
|
||||
if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(tool.getModelIds())) {
|
||||
if (CollectionUtils.isEmpty(tool.getViewIds())) {
|
||||
return true;
|
||||
}
|
||||
if (tool.isContainsAllModel()) {
|
||||
return false;
|
||||
}
|
||||
if (new HashSet<>(tool.getModelIds())
|
||||
.containsAll(query.getParseInfo().getModel().getModelIds())) {
|
||||
return false;
|
||||
}
|
||||
return !tool.getViewIds().contains(query.getParseInfo().getViewId());
|
||||
}
|
||||
return true;
|
||||
});
|
||||
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
|
||||
queryContext.setCandidateQueries(queries);
|
||||
log.info("agent name :{}, rule queries witch can be supported by agent :{}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
@@ -12,8 +11,8 @@ import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -23,8 +22,6 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* ContextInheritParser tries to inherit certain schema elements from context
|
||||
@@ -42,7 +39,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.MODEL, Arrays.asList(SchemaElementType.MODEL)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.VIEW, Arrays.asList(SchemaElementType.VIEW)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@@ -51,12 +48,13 @@ public class ContextInheritParser implements SemanticParser {
|
||||
if (!shouldInherit(queryContext)) {
|
||||
return;
|
||||
}
|
||||
ModelCluster modelCluster = getMatchedModelCluster(queryContext, chatContext);
|
||||
if (modelCluster == null) {
|
||||
Long viewId = getMatchedView(queryContext, chatContext);
|
||||
if (viewId == null) {
|
||||
return;
|
||||
}
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
@@ -72,17 +70,17 @@ public class ContextInheritParser implements SemanticParser {
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
if (existSameQuery(query.getParseInfo().getModelClusterKey(), query.getQueryMode(), queryContext)) {
|
||||
if (existSameQuery(query.getParseInfo().getViewId(), query.getQueryMode(), queryContext)) {
|
||||
continue;
|
||||
}
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSameQuery(String modelClusterKey, String queryMode, QueryContext queryContext) {
|
||||
private boolean existSameQuery(Long viewId, String queryMode, QueryContext queryContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||
&& semanticQuery.getParseInfo().getModelClusterKey().equals(modelClusterKey)) {
|
||||
&& semanticQuery.getParseInfo().getViewId().equals(viewId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -111,25 +109,16 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected ModelCluster getMatchedModelCluster(QueryContext queryContext, ChatContext chatContext) {
|
||||
String contextModelClusterKey = chatContext.getParseInfo().getModelClusterKey();
|
||||
if (StringUtils.isBlank(contextModelClusterKey)) {
|
||||
protected Long getMatchedView(QueryContext queryContext, ChatContext chatContext) {
|
||||
Long viewId = chatContext.getParseInfo().getViewId();
|
||||
if (viewId == null) {
|
||||
return null;
|
||||
}
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
List<ModelCluster> allModelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
Set<String> queryModelClusters = queryContext.getModelClusterMapInfo().getMatchedModelClusters();
|
||||
ModelCluster contextModelCluster = ModelCluster.build(contextModelClusterKey);
|
||||
for (String cluster : queryModelClusters) {
|
||||
ModelCluster queryModelCluster = ModelCluster.build(cluster);
|
||||
for (ModelCluster modelCluster : allModelClusters) {
|
||||
if (modelCluster.getModelIds().containsAll(contextModelCluster.getModelIds())
|
||||
&& modelCluster.getModelIds().containsAll(queryModelCluster.getModelIds())) {
|
||||
return queryModelCluster;
|
||||
}
|
||||
}
|
||||
Set<Long> queryViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||
if (queryViews.contains(viewId)) {
|
||||
return viewId;
|
||||
}
|
||||
return null;
|
||||
return viewId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -27,10 +27,10 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
SchemaModelClusterMapInfo modelClusterMapInfo = queryContext.getModelClusterMapInfo();
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (String modelClusterKey : modelClusterMapInfo.getMatchedModelClusters()) {
|
||||
List<SchemaElementMatch> elementMatches = modelClusterMapInfo.getMatchedElements(modelClusterKey);
|
||||
for (Long viewId : mapInfo.getMatchedViewInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(viewId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
|
||||
@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
|
||||
*/
|
||||
private String type;
|
||||
|
||||
private List<Long> modelList = Lists.newArrayList();
|
||||
private List<Long> viewList = Lists.newArrayList();
|
||||
|
||||
/**
|
||||
* description, for parsing
|
||||
@@ -52,7 +52,7 @@ public class Plugin extends RecordInfo {
|
||||
}
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(modelList) && modelList.contains(-1L);
|
||||
return CollectionUtils.isNotEmpty(viewList) && viewList.contains(-1L);
|
||||
}
|
||||
|
||||
public Long getDefaultMode() {
|
||||
|
||||
@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.core.plugin;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||
@@ -23,6 +23,12 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
@@ -32,11 +38,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -265,14 +266,14 @@ public class PluginManager {
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
|
||||
Set<Long> matchedModel = queryContext.getMapInfo().getMatchedModels();
|
||||
Set<Long> matchedViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
List<Long> modelIds = plugin.getModelList();
|
||||
List<Long> modelIds = plugin.getViewList();
|
||||
Set<Long> pluginMatchedModel = Sets.newHashSet();
|
||||
for (Long modelId : modelIds) {
|
||||
if (matchedModel.contains(modelId)) {
|
||||
if (matchedViews.contains(modelId)) {
|
||||
pluginMatchedModel.add(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@@ -14,7 +15,7 @@ public class PluginRecallResult {
|
||||
|
||||
private Plugin plugin;
|
||||
|
||||
private Set<Long> modelIds;
|
||||
private Set<Long> viewIds;
|
||||
|
||||
private double score;
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.pojo;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
@@ -11,15 +11,16 @@ import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@@ -29,18 +30,22 @@ public class QueryContext {
|
||||
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId;
|
||||
private Long viewId;
|
||||
private User user;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
@JsonIgnore
|
||||
private Agent agent;
|
||||
@JsonIgnore
|
||||
private Map<Long, ChatConfigRichResp> modelIdToChatRichConfig;
|
||||
@JsonIgnore
|
||||
private Map<String, Plugin> nameToPlugin;
|
||||
@JsonIgnore
|
||||
private List<Plugin> pluginList;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
@@ -14,20 +14,21 @@ import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@@ -48,7 +49,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryType.SQL)
|
||||
.queryReq(QueryReqBuilder.buildS2SQLReq(
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getModel().getModelIds()
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getViewId()
|
||||
))
|
||||
.build();
|
||||
} else {
|
||||
@@ -83,7 +84,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getModelIdSet());
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getViewId());
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
@@ -100,18 +101,17 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList());
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
queryStructReq.setModelName(parseInfo.getModelName());
|
||||
}
|
||||
|
||||
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
|
||||
@@ -121,7 +121,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(semanticSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert(queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
@@ -2,14 +2,14 @@ package com.tencent.supersonic.chat.core.query.llm.analytics;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
@@ -19,8 +19,8 @@ import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user