mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
Compare commits
176 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d095f9676 | ||
|
|
7c7ccadcfd | ||
|
|
02b9dc6947 | ||
|
|
4222d7e2b5 | ||
|
|
b5daf04c96 | ||
|
|
d26089a249 | ||
|
|
627683d437 | ||
|
|
87e222eecc | ||
|
|
667272b103 | ||
|
|
3f08d95aaa | ||
|
|
27ebda3439 | ||
|
|
52eca178d3 | ||
|
|
f5a064aaad | ||
|
|
41e585324d | ||
|
|
c36082476f | ||
|
|
4bc1378285 | ||
|
|
23d01c4f83 | ||
|
|
49ebb70cb3 | ||
|
|
27bb1b322e | ||
|
|
0534053ff9 | ||
|
|
fe2a424718 | ||
|
|
d79e30cd7a | ||
|
|
aa433baa06 | ||
|
|
30bb9a1dc0 | ||
|
|
c168925f03 | ||
|
|
42c0bea8fc | ||
|
|
291c00749a | ||
|
|
6763ea0f7b | ||
|
|
f917defea8 | ||
|
|
91718592d4 | ||
|
|
6d9a8095eb | ||
|
|
5b8fdbc6fd | ||
|
|
e15e44e4a2 | ||
|
|
980d317152 | ||
|
|
2c23c2f574 | ||
|
|
80ad75503b | ||
|
|
0143b0a1b2 | ||
|
|
dd115f9d37 | ||
|
|
f198ce1ef8 | ||
|
|
18211a215d | ||
|
|
d6a386ad03 | ||
|
|
8f19584ad7 | ||
|
|
d9eaf79ab8 | ||
|
|
05b1a7ec3b | ||
|
|
11cdcb29fa | ||
|
|
46a9e5b097 | ||
|
|
8c65ac80b5 | ||
|
|
5b3a9ffba8 | ||
|
|
8688c8c2b3 | ||
|
|
13d8b9cff5 | ||
|
|
aa448b1ba3 | ||
|
|
7ef3d92f2c | ||
|
|
9f09598ccd | ||
|
|
36c8938ff7 | ||
|
|
3271db4ca6 | ||
|
|
400d8b34fd | ||
|
|
d4374f7074 | ||
|
|
438ee539d6 | ||
|
|
5ccde0206c | ||
|
|
74ed269544 | ||
|
|
1ad2c5402b | ||
|
|
805abeb261 | ||
|
|
551a376b00 | ||
|
|
47be92d5f6 | ||
|
|
5feac0c14e | ||
|
|
0f02e21eaa | ||
|
|
cdb84716b7 | ||
|
|
731238de08 | ||
|
|
cb1ad94086 | ||
|
|
24b0be7566 | ||
|
|
3241ef87a3 | ||
|
|
bd541e1199 | ||
|
|
f998f27c6f | ||
|
|
cf788316c3 | ||
|
|
8ed7e91221 | ||
|
|
e537b738e4 | ||
|
|
bf3a111e55 | ||
|
|
63a526709d | ||
|
|
e0088e8f8f | ||
|
|
7d33c49db8 | ||
|
|
acee0a36da | ||
|
|
a528ba6070 | ||
|
|
ba224ac335 | ||
|
|
18aa14118c | ||
|
|
4e139c837a | ||
|
|
6ad74bb206 | ||
|
|
16c3de44e4 | ||
|
|
608a4f7a2f | ||
|
|
cd972d0850 | ||
|
|
2aeeb1a14e | ||
|
|
41aa6ff8e4 | ||
|
|
67f658ced2 | ||
|
|
94fa86629d | ||
|
|
2cb0640a7b | ||
|
|
772d5bd3ae | ||
|
|
d6681ead60 | ||
|
|
d94fd4714f | ||
|
|
0365886270 | ||
|
|
aa6c658a9a | ||
|
|
6e3f871015 | ||
|
|
6c9983164e | ||
|
|
e00b935c1f | ||
|
|
f5f9c0314a | ||
|
|
910384d17f | ||
|
|
2fe56e7462 | ||
|
|
b8989e204f | ||
|
|
84b7c2c062 | ||
|
|
14373309aa | ||
|
|
9cd3e22721 | ||
|
|
2f812372d7 | ||
|
|
9f813ca1c0 | ||
|
|
70784598e1 | ||
|
|
ad20380283 | ||
|
|
f4e3922f47 | ||
|
|
bfac71a7d0 | ||
|
|
435e789fa4 | ||
|
|
b9f5e0a354 | ||
|
|
372e4acc2c | ||
|
|
8f37c3175f | ||
|
|
438e8463f5 | ||
|
|
ae9aa1ba0f | ||
|
|
688d26c457 | ||
|
|
8b99b46787 | ||
|
|
80cce47f58 | ||
|
|
c92184d89f | ||
|
|
a8fe575999 | ||
|
|
38099c8cc7 | ||
|
|
32e51257f6 | ||
|
|
e44e7ca8d5 | ||
|
|
d533496b2a | ||
|
|
eb9db28352 | ||
|
|
836ee5f3ed | ||
|
|
0fa31f84a3 | ||
|
|
9a3c71df4a | ||
|
|
e4e39e0496 | ||
|
|
cd901fbc68 | ||
|
|
8fde378534 | ||
|
|
d8f81aca65 | ||
|
|
b9895d541b | ||
|
|
4fbc3c8533 | ||
|
|
62e2bf7de6 | ||
|
|
c6f9ea2b20 | ||
|
|
166a3cfe09 | ||
|
|
cbf84876de | ||
|
|
8bd43f113b | ||
|
|
d710986923 | ||
|
|
dd63b78937 | ||
|
|
8d1a07585b | ||
|
|
f4638b48d5 | ||
|
|
a1d56fc7e4 | ||
|
|
9879c99873 | ||
|
|
1016efc646 | ||
|
|
156aa6822b | ||
|
|
34eb94320e | ||
|
|
ba1d14f40a | ||
|
|
dc4fbb1a14 | ||
|
|
7d770d2a6d | ||
|
|
7b861f563c | ||
|
|
65614ed3ba | ||
|
|
7acdf9cb3d | ||
|
|
2e4954a4e8 | ||
|
|
36052cb4f2 | ||
|
|
8d81f63e08 | ||
|
|
bf5be11549 | ||
|
|
36907ccac1 | ||
|
|
0f3e9e8754 | ||
|
|
883cdbefbe | ||
|
|
968d50e071 | ||
|
|
a9bb1c1f68 | ||
|
|
207d6cba43 | ||
|
|
40ba179703 | ||
|
|
5b8fde70ca | ||
|
|
e3232f0198 | ||
|
|
c5536aa25d | ||
|
|
37bb9ff767 | ||
|
|
f2e8207245 |
@@ -27,7 +27,7 @@ move webapp ..\..\launchers\standalone\target\classes
|
||||
|
||||
rem 5. build backend python modules
|
||||
echo "start installing python modules with pip: ${pip_path}"
|
||||
set requirementPath="%baseDir%/../chat/core/src/main/python/requirements.txt"
|
||||
set requirementPath="%baseDir%/../chat/python/requirements.txt"
|
||||
%pip_path% install -r %requirementPath%
|
||||
echo "install python modules success"
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ rm -fr ${buildDir}/webapp
|
||||
|
||||
#5. build backend python modules
|
||||
echo "start installing python modules with pip: ${pip_path}"
|
||||
requirementPath=$baseDir/../chat/core/src/main/python/requirements.txt
|
||||
requirementPath=$baseDir/../chat/python/requirements.txt
|
||||
${pip_path} install -r ${requirementPath}
|
||||
echo "install python modules success"
|
||||
|
||||
|
||||
@@ -96,10 +96,11 @@ function runPythonService {
|
||||
break
|
||||
else
|
||||
if [ "$i" -eq 10 ]; then
|
||||
echo "llmparser Health check failed after 10 attempts. Exiting."
|
||||
echo "llmparser Health check failed after 10 attempts."
|
||||
echo "May still downloading model files. Please check llmparser.log in runtime directory."
|
||||
fi
|
||||
echo "Retrying after 5 seconds..."
|
||||
sleep 5
|
||||
fi
|
||||
done
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
</includes>
|
||||
</fileSet>
|
||||
<fileSet>
|
||||
<directory>${project.basedir}/../../chat/core/src/main/python</directory>
|
||||
<directory>${project.basedir}/../../chat/python</directory>
|
||||
<outputDirectory>llmparser</outputDirectory>
|
||||
<fileMode>0777</fileMode>
|
||||
<directoryMode>0755</directoryMode>
|
||||
|
||||
@@ -6,7 +6,7 @@ import lombok.Data;
|
||||
@Data
|
||||
public class AuthGroup {
|
||||
|
||||
private String modelId;
|
||||
private Long modelId;
|
||||
private String name;
|
||||
private Integer groupId;
|
||||
private List<AuthRule> authRules;
|
||||
|
||||
@@ -7,13 +7,13 @@ import lombok.ToString;
|
||||
@ToString
|
||||
public class AuthRes {
|
||||
|
||||
private String modelId;
|
||||
private Long modelId;
|
||||
private String name;
|
||||
|
||||
public AuthRes() {
|
||||
}
|
||||
|
||||
public AuthRes(String modelId, String name) {
|
||||
public AuthRes(Long modelId, String name) {
|
||||
this.modelId = modelId;
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package com.tencent.supersonic.auth.api.authorization.request;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@@ -15,5 +17,17 @@ public class QueryAuthResReq {
|
||||
|
||||
private List<AuthRes> resources;
|
||||
|
||||
private String modelId;
|
||||
private Long modelId;
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
public List<Long> getModelIds() {
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
return modelIds;
|
||||
}
|
||||
if (modelId != null) {
|
||||
return Lists.newArrayList(modelId);
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,12 +33,6 @@
|
||||
<artifactId>spring-boot-starter-jdbc</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mybatis</groupId>
|
||||
<artifactId>mybatis</artifactId>
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>druid</artifactId>
|
||||
@@ -52,12 +46,7 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mybatis</groupId>
|
||||
<artifactId>mybatis-spring</artifactId>
|
||||
<version>${mybatis-spring.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.pagehelper</groupId>
|
||||
<artifactId>pagehelper</artifactId>
|
||||
|
||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -48,7 +47,7 @@ public class AuthServiceImpl implements AuthService {
|
||||
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
|
||||
return load().stream()
|
||||
.filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
|
||||
&& modelId.equals(group.getModelId()))
|
||||
&& modelId.equals(group.getModelId().toString()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -80,17 +79,14 @@ public class AuthServiceImpl implements AuthService {
|
||||
@Override
|
||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||
if (!CollectionUtils.isEmpty(userOrgIds)) {
|
||||
req.setDepartmentIds(new ArrayList<>(userOrgIds));
|
||||
}
|
||||
List<AuthGroup> groups = getAuthGroups(req, user.getName());
|
||||
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
|
||||
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
||||
Map<String, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
||||
Map<String, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
||||
Map<Long, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
||||
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
||||
|
||||
for (String modelId : reqAuthRes.keySet()) {
|
||||
for (Long modelId : reqAuthRes.keySet()) {
|
||||
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
||||
AuthResGrp rg = new AuthResGrp();
|
||||
if (authGroupsByModelId.containsKey(modelId)) {
|
||||
@@ -113,7 +109,7 @@ public class AuthServiceImpl implements AuthService {
|
||||
}
|
||||
}
|
||||
|
||||
if (StringUtils.isNotEmpty(req.getModelId())) {
|
||||
if (req.getModelId() != null) {
|
||||
List<AuthGroup> authGroups = authGroupsByModelId.get(req.getModelId());
|
||||
if (!CollectionUtils.isEmpty(authGroups)) {
|
||||
for (AuthGroup group : authGroups) {
|
||||
@@ -130,17 +126,17 @@ public class AuthServiceImpl implements AuthService {
|
||||
return resource;
|
||||
}
|
||||
|
||||
private List<AuthGroup> getAuthGroups(QueryAuthResReq req, String userName) {
|
||||
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
|
||||
List<AuthGroup> groups = load().stream()
|
||||
.filter(group -> {
|
||||
if (!Objects.equals(group.getModelId(), req.getModelId())) {
|
||||
if (CollectionUtils.isEmpty(modelIds) || !modelIds.contains(group.getModelId())) {
|
||||
return false;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
||||
.contains(userName)) {
|
||||
return true;
|
||||
}
|
||||
for (String departmentId : req.getDepartmentIds()) {
|
||||
for (String departmentId : departmentIds) {
|
||||
if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
|
||||
&& group.getAuthorizedDepartmentIds().contains(departmentId)) {
|
||||
return true;
|
||||
@@ -148,7 +144,7 @@ public class AuthServiceImpl implements AuthService {
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups);
|
||||
log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups);
|
||||
return groups;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.component;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
|
||||
/**
|
||||
* A semantic corrector checks validity of extracted semantic information and
|
||||
@@ -9,5 +9,5 @@ import net.sf.jsqlparser.JSQLParserException;
|
||||
*/
|
||||
public interface SemanticCorrector {
|
||||
|
||||
void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException;
|
||||
void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||
}
|
||||
|
||||
@@ -9,12 +9,13 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
|
||||
@@ -37,7 +38,7 @@ public interface SemanticInterpreter {
|
||||
|
||||
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||
|
||||
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
|
||||
QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);
|
||||
|
||||
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||
|
||||
@@ -47,9 +48,9 @@ public interface SemanticInterpreter {
|
||||
|
||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
||||
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||
|
||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user);
|
||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
||||
|
||||
List<DomainResp> getDomainList(User user);
|
||||
|
||||
@@ -57,4 +58,6 @@ public interface SemanticInterpreter {
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
||||
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.api.component;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
|
||||
/**
|
||||
@@ -15,7 +14,9 @@ public interface SemanticQuery {
|
||||
|
||||
QueryResult execute(User user) throws SqlParseException;
|
||||
|
||||
ExplainResp explain(User user);
|
||||
void initS2Sql(User user);
|
||||
|
||||
String explain(User user);
|
||||
|
||||
SemanticParseInfo getParseInfo();
|
||||
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -13,7 +18,9 @@ public class ModelSchema {
|
||||
private Set<SchemaElement> metrics = new HashSet<>();
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
private Set<SchemaElement> tags = new HashSet<>();
|
||||
private SchemaElement entity = new SchemaElement();
|
||||
private List<ModelRela> modelRelas = new ArrayList<>();
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
@@ -34,6 +41,9 @@ public class ModelSchema {
|
||||
case VALUE:
|
||||
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||
break;
|
||||
case TAG:
|
||||
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
@@ -44,4 +54,45 @@ public class ModelSchema {
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, String name) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = Optional.ofNullable(entity);
|
||||
break;
|
||||
case MODEL:
|
||||
element = Optional.of(model);
|
||||
break;
|
||||
case METRIC:
|
||||
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||
break;
|
||||
case VALUE:
|
||||
element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public Set<Long> getModelClusterSet() {
|
||||
if (CollectionUtils.isEmpty(modelRelas)) {
|
||||
return Sets.newHashSet();
|
||||
}
|
||||
Set<Long> modelClusterSet = new HashSet<>();
|
||||
modelRelas.forEach(modelRela -> {
|
||||
modelClusterSet.add(modelRela.getToModelId());
|
||||
modelClusterSet.add(modelRela.getFromModelId());
|
||||
});
|
||||
return modelClusterSet;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ public class QueryContext {
|
||||
private QueryReq request;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
|
||||
public QueryContext(QueryReq request) {
|
||||
this.request = request;
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class RelatedSchemaElement {
|
||||
|
||||
private Long dimensionId;
|
||||
|
||||
private boolean isNecessary;
|
||||
|
||||
}
|
||||
@@ -1,14 +1,15 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@Builder
|
||||
@@ -22,13 +23,14 @@ public class SchemaElement implements Serializable {
|
||||
private String bizName;
|
||||
private Long useCnt;
|
||||
private SchemaElementType type;
|
||||
|
||||
private List<String> alias;
|
||||
|
||||
private List<SchemaValueMap> schemaValueMaps;
|
||||
private List<RelatedSchemaElement> relatedSchemaElements;
|
||||
|
||||
private String defaultAgg;
|
||||
|
||||
private double order;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
@@ -40,13 +42,13 @@ public class SchemaElement implements Serializable {
|
||||
SchemaElement schemaElement = (SchemaElement) o;
|
||||
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
|
||||
schemaElement.id) && Objects.equal(name, schemaElement.name)
|
||||
&& Objects.equal(bizName, schemaElement.bizName) && Objects.equal(
|
||||
useCnt, schemaElement.useCnt) && Objects.equal(type, schemaElement.type);
|
||||
&& Objects.equal(bizName, schemaElement.bizName)
|
||||
&& Objects.equal(type, schemaElement.type);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(model, id, name, bizName, useCnt, type);
|
||||
return Objects.hashCode(model, id, name, bizName, type);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ public enum SchemaElementType {
|
||||
DIMENSION,
|
||||
VALUE,
|
||||
ENTITY,
|
||||
TAG,
|
||||
ID,
|
||||
DATE
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class SchemaModelClusterMapInfo {
|
||||
|
||||
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||
|
||||
public Set<String> getMatchedModelClusters() {
|
||||
return modelElementMatches.keySet();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
return modelElementMatches.get(key);
|
||||
}
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
||||
return modelElementMatches.get(modelCluster);
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
||||
return modelElementMatches;
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
||||
if (CollectionUtils.isEmpty(modelIds)) {
|
||||
return modelElementMatches;
|
||||
}
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
for (Long modelId : modelIds) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
||||
}
|
||||
}
|
||||
}
|
||||
return modelElementMatchesFiltered;
|
||||
}
|
||||
|
||||
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
||||
this.modelElementMatches = modelElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
||||
modelElementMatches.put(modelCluster, elementMatches);
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -15,15 +20,13 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SemanticParseInfo {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode;
|
||||
private SchemaElement model;
|
||||
private ModelCluster model = new ModelCluster();
|
||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||
private SchemaElement entity;
|
||||
@@ -34,25 +37,38 @@ public class SemanticParseInfo {
|
||||
private Set<Order> orders = new LinkedHashSet();
|
||||
private DateConf dateInfo;
|
||||
private Long limit;
|
||||
private Boolean nativeQuery = false;
|
||||
private double score;
|
||||
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
||||
private Map<String, Object> properties = new HashMap<>();
|
||||
private EntityInfo entityInfo;
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private QueryType queryType = QueryType.OTHER;
|
||||
|
||||
public Long getModelId() {
|
||||
return model != null ? model.getId() : 0L;
|
||||
public String getModelClusterKey() {
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getKey();
|
||||
}
|
||||
|
||||
public String getModelName() {
|
||||
return model != null ? model.getName() : "null";
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getName();
|
||||
}
|
||||
|
||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||
|
||||
@Override
|
||||
public int compare(SchemaElement o1, SchemaElement o2) {
|
||||
if (o1.getOrder() != o2.getOrder()) {
|
||||
if (o1.getOrder() < o2.getOrder()) {
|
||||
return -1;
|
||||
} else {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
int len1 = o1.getName().length();
|
||||
int len2 = o2.getName().length();
|
||||
if (len1 != len2) {
|
||||
@@ -70,4 +86,26 @@ public class SemanticParseInfo {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private Map<Long, Integer> getModelElementCountMap() {
|
||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||
elementMatches.forEach(element -> {
|
||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||
});
|
||||
return elementCountMap;
|
||||
}
|
||||
|
||||
public Long getModelId() {
|
||||
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
||||
Long modelId = -1L;
|
||||
int maxCnt = 0;
|
||||
for (Long model : elementCountMap.keySet()) {
|
||||
if (elementCountMap.get(model) > maxCnt) {
|
||||
maxCnt = elementCountMap.get(model);
|
||||
modelId = model;
|
||||
}
|
||||
}
|
||||
return modelId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
@@ -18,6 +23,64 @@ public class SemanticSchema implements Serializable {
|
||||
modelSchemaList.add(schema);
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsById(elementID, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsById(elementID, getModels());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsById(elementID, getMetrics());
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = getElementsById(elementID, getDimensions());
|
||||
break;
|
||||
case VALUE:
|
||||
element = getElementsById(elementID, getDimensionValues());
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsByName(name, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsByName(name, getModels());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsByName(name, getMetrics());
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = getElementsByName(name, getDimensions());
|
||||
break;
|
||||
case VALUE:
|
||||
element = getElementsByName(name, getDimensionValues());
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public Map<Long, String> getModelIdToName() {
|
||||
return modelSchemaList.stream()
|
||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
||||
@@ -35,9 +98,28 @@ public class SemanticSchema implements Serializable {
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions(Long modelId) {
|
||||
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
return getElementsByModelId(modelId, dimensions);
|
||||
return getElementsByModelId(modelIds, dimensions);
|
||||
}
|
||||
|
||||
public SchemaElement getDimensions(Long id) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||
return dimension.orElse(null);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags() {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics() {
|
||||
@@ -46,26 +128,54 @@ public class SemanticSchema implements Serializable {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics(Long modelId) {
|
||||
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
||||
List<SchemaElement> metrics = getMetrics();
|
||||
return getElementsByModelId(modelId, metrics);
|
||||
return getElementsByModelId(modelIds, metrics);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
|
||||
public List<SchemaElement> getEntities() {
|
||||
List<SchemaElement> entities = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
return entities;
|
||||
}
|
||||
|
||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> id.equals(schemaElement.getId()))
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> name.equals(schemaElement.getName()))
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
public List<SchemaElement> getModels() {
|
||||
List<SchemaElement> models = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||
return models;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities() {
|
||||
List<SchemaElement> entities = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
return entities;
|
||||
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(getDimensions(modelIds));
|
||||
allElements.addAll(getMetrics(modelIds));
|
||||
return allElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public Map<Long, ModelSchema> getModelSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,11 @@ public class ChatConfigBaseReq {
|
||||
*/
|
||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||
|
||||
/**
|
||||
* the llm examples about the model
|
||||
*/
|
||||
private String llmExamples;
|
||||
|
||||
/**
|
||||
* available status
|
||||
*/
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DimensionValueReq {
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
@NotNull
|
||||
private Long elementID;
|
||||
|
||||
@NotNull
|
||||
private Long modelId;
|
||||
|
||||
private String bizName;
|
||||
|
||||
private Object value;
|
||||
@NotNull
|
||||
private String value;
|
||||
}
|
||||
|
||||
@@ -3,16 +3,18 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
public class ExecuteQueryReq {
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId = 7L;
|
||||
private Integer parseId = 2;
|
||||
private Long queryId;
|
||||
private Integer parseId;
|
||||
private SemanticParseInfo parseInfo;
|
||||
private boolean saveAnswer = true;
|
||||
private boolean saveAnswer;
|
||||
}
|
||||
|
||||
@@ -13,4 +13,8 @@ public class PageQueryInfoReq {
|
||||
private String userName;
|
||||
|
||||
private List<Long> ids;
|
||||
|
||||
public Integer getLimitStart() {
|
||||
return this.pageSize * (this.current - 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,6 @@ public class QueryDataReq {
|
||||
private Set<QueryFilter> dimensionFilters = new HashSet<>();
|
||||
private Set<QueryFilter> metricFilters = new HashSet<>();
|
||||
private DateConf dateInfo;
|
||||
private Long queryId = 7L;
|
||||
private Integer parseId = 2;
|
||||
private Long queryId;
|
||||
private Integer parseId;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
||||
public class QueryReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId = 0L;
|
||||
private Long modelId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class RecommendReq {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private Long metricId;
|
||||
|
||||
}
|
||||
@@ -18,7 +18,7 @@ public class SolvedQueryReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private Long modelId;
|
||||
private String modelId;
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
|
||||
@@ -23,6 +23,8 @@ public class ChatConfigResp {
|
||||
|
||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||
|
||||
private String llmExamples;
|
||||
|
||||
/**
|
||||
* available status
|
||||
*/
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelInfo extends DataInfo implements Serializable {
|
||||
|
||||
private List<String> words;
|
||||
private String primaryEntityName;
|
||||
private String primaryEntityBizName;
|
||||
private String primaryKey;
|
||||
}
|
||||
|
||||
@@ -1,31 +1,24 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
import lombok.Builder;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.AllArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ParseResp {
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId;
|
||||
private ParseState state;
|
||||
private List<SemanticParseInfo> selectedParses;
|
||||
private List<SemanticParseInfo> candidateParses;
|
||||
private List<SolvedQueryRecallResp> similarSolvedQuery;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
|
||||
|
||||
public enum ParseState {
|
||||
COMPLETED,
|
||||
PENDING,
|
||||
FAILED
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ParseTimeCostDO {
|
||||
|
||||
private long parseStartTime;
|
||||
private long parseTime;
|
||||
private long sqlTime;
|
||||
|
||||
public ParseTimeCostDO() {
|
||||
this.parseStartTime = System.currentTimeMillis();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class QueryRecallResp {
|
||||
private List<SolvedQueryRecallResp> solvedQueryRecallRespList;
|
||||
private Long queryTimeCost;
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@@ -13,4 +15,5 @@ public class QueryResp {
|
||||
private String feedback;
|
||||
private String queryText;
|
||||
private QueryResult queryResult;
|
||||
private List<SemanticParseInfo> parseInfos;
|
||||
}
|
||||
@@ -21,4 +21,5 @@ public class QueryResult {
|
||||
private SemanticParseInfo chatContext;
|
||||
private Object response;
|
||||
private List<Map<String, Object>> queryResults;
|
||||
private Long queryTimeCost;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import lombok.Data;
|
||||
@Data
|
||||
public class SqlInfo {
|
||||
|
||||
private String llmParseSql;
|
||||
private String logicSql;
|
||||
private String querySql;
|
||||
private String s2SQL;
|
||||
private String correctS2SQL;
|
||||
private String querySQL;
|
||||
}
|
||||
|
||||
@@ -59,16 +59,7 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mybatis</groupId>
|
||||
<artifactId>mybatis</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mybatis.spring.boot</groupId>
|
||||
<artifactId>mybatis-spring-boot-starter</artifactId>
|
||||
<version>${mybatis-spring.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>druid</artifactId>
|
||||
@@ -78,24 +69,6 @@
|
||||
<groupId>mysql</groupId>
|
||||
<artifactId>mysql-connector-java</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mybatis</groupId>
|
||||
<artifactId>mybatis-spring</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mybatis.spring.boot</groupId>
|
||||
<artifactId>mybatis-spring-boot-starter-test</artifactId>
|
||||
<version>${mybatis.test.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.pagehelper</groupId>
|
||||
<artifactId>pagehelper-spring-boot-starter</artifactId>
|
||||
<version>${pagehelper.spring.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.h2database</groupId>
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
public enum AgentToolType {
|
||||
RULE,
|
||||
DSL,
|
||||
LLM_S2SQL,
|
||||
PLUGIN,
|
||||
INTERPRET
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class CommonAgentTool extends AgentTool {
|
||||
|
||||
protected List<Long> modelIds;
|
||||
|
||||
}
|
||||
@@ -5,9 +5,7 @@ import lombok.Data;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class DslTool extends AgentTool {
|
||||
|
||||
private List<Long> modelIds;
|
||||
public class LLMParserTool extends CommonAgentTool {
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
@@ -7,12 +7,13 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RuleQueryTool extends AgentTool {
|
||||
public class RuleQueryTool extends CommonAgentTool {
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
private List<String> queryModes;
|
||||
|
||||
private List<String> queryTypes;
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
||||
}
|
||||
|
||||
@@ -1,43 +1,161 @@
|
||||
package com.tencent.supersonic.chat.config;
|
||||
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.PropertySource;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
@PropertySource("classpath:optimization.properties")
|
||||
//@ComponentScan(basePackages = "com.tencent.supersonic.chat")
|
||||
@Slf4j
|
||||
public class OptimizationConfig {
|
||||
|
||||
@Value("${one.detection.size}")
|
||||
@Value("${one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
@Value("${one.detection.max.size}")
|
||||
|
||||
@Value("${one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
|
||||
@Value("${metric.dimension.min.threshold}")
|
||||
@Value("${metric.dimension.min.threshold:0.3}")
|
||||
private Double metricDimensionMinThresholdConfig;
|
||||
|
||||
@Value("${metric.dimension.threshold}")
|
||||
@Value("${metric.dimension.threshold:0.3}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${dimension.value.threshold}")
|
||||
@Value("${dimension.value.threshold:0.5}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
|
||||
@Value("${function.bonus.threshold}")
|
||||
private Double functionBonusThreshold;
|
||||
|
||||
@Value("${long.text.threshold}")
|
||||
@Value("${long.text.threshold:0.8}")
|
||||
private Double longTextThreshold;
|
||||
|
||||
@Value("${short.text.threshold}")
|
||||
@Value("${short.text.threshold:0.5}")
|
||||
private Double shortTextThreshold;
|
||||
|
||||
@Value("${query.text.length.threshold}")
|
||||
@Value("${query.text.length.threshold:10}")
|
||||
private Integer queryTextLengthThreshold;
|
||||
@Value("${embedding.mapper.word.min:4}")
|
||||
private int embeddingMapperWordMin;
|
||||
|
||||
@Value("${candidate.threshold}")
|
||||
private Double candidateThreshold;
|
||||
@Value("${embedding.mapper.word.max:5}")
|
||||
private int embeddingMapperWordMax;
|
||||
|
||||
@Value("${embedding.mapper.batch:50}")
|
||||
private int embeddingMapperBatch;
|
||||
|
||||
@Value("${embedding.mapper.number:5}")
|
||||
private int embeddingMapperNumber;
|
||||
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${embedding.mapper.distance.threshold:0.58}")
|
||||
private Double embeddingMapperDistanceThreshold;
|
||||
|
||||
@Value("${s2SQL.linking.value.switch:true}")
|
||||
private boolean useLinkingValueSwitch;
|
||||
|
||||
@Value("${s2SQL.use.switch:true}")
|
||||
private boolean useS2SqlSwitch;
|
||||
|
||||
@Value("${text2sql.example.num:10}")
|
||||
private int text2sqlExampleNum;
|
||||
|
||||
@Value("${text2sql.fewShots.num:10}")
|
||||
private int text2sqlFewShotsNum;
|
||||
|
||||
@Value("${text2sql.self.consistency.num:5}")
|
||||
private int text2sqlSelfConsistencyNum;
|
||||
|
||||
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
||||
private String text2sqlCollectionName;
|
||||
|
||||
@Autowired
|
||||
private SysParameterService sysParameterService;
|
||||
|
||||
public Integer getOneDetectionSize() {
|
||||
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
|
||||
}
|
||||
|
||||
public Integer getOneDetectionMaxSize() {
|
||||
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||
}
|
||||
|
||||
public Double getMetricDimensionMinThresholdConfig() {
|
||||
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getMetricDimensionThresholdConfig() {
|
||||
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueThresholdConfig() {
|
||||
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getLongTextThreshold() {
|
||||
return convertValue("long.text.threshold", Double.class, longTextThreshold);
|
||||
}
|
||||
|
||||
public Double getShortTextThreshold() {
|
||||
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
|
||||
}
|
||||
|
||||
public Integer getQueryTextLengthThreshold() {
|
||||
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||
}
|
||||
|
||||
public boolean isUseS2SqlSwitch() {
|
||||
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperWordMin() {
|
||||
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperWordMax() {
|
||||
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperBatch() {
|
||||
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperNumber() {
|
||||
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperRoundNumber() {
|
||||
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperDistanceThreshold() {
|
||||
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
|
||||
}
|
||||
|
||||
public boolean isUseLinkingValueSwitch() {
|
||||
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||
}
|
||||
|
||||
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||
try {
|
||||
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||
if (StringUtils.isBlank(value)) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (targetType == Double.class) {
|
||||
return targetType.cast(Double.parseDouble(value));
|
||||
} else if (targetType == Integer.class) {
|
||||
return targetType.cast(Integer.parseInt(value));
|
||||
} else if (targetType == Boolean.class) {
|
||||
return targetType.cast(Boolean.parseBoolean(value));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("convertValue", e);
|
||||
}
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,14 +2,20 @@ package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -17,17 +23,30 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* basic semantic correction functionality, offering common methods and an
|
||||
* abstract method called doCorrect
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
||||
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
return;
|
||||
}
|
||||
doCorrect(queryReq, semanticParseInfo);
|
||||
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||
} catch (Exception e) {
|
||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||
}
|
||||
}
|
||||
|
||||
protected Map<String, String> getFieldNameMap(Long modelId) {
|
||||
|
||||
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
@@ -35,34 +54,59 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||
|
||||
// support fieldName and field alias
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> entry.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1));
|
||||
result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
|
||||
.filter(entry -> modelIds.contains(entry.getModel()))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
|
||||
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
|
||||
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
|
||||
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
|
||||
|
||||
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
|
||||
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
|
||||
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(sql));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
||||
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||
|
||||
// If there is no aggregate function in the S2SQL statement and
|
||||
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||
.collect(Collectors.toSet());
|
||||
needAddFields.addAll(timeFields);
|
||||
}
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
needAddFields.removeAll(selectFields);
|
||||
needAddFields.remove(DateUtils.DATE_FIELD);
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields));
|
||||
semanticCorrectInfo.setSql(replaceFields);
|
||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
List<SchemaElement> metrics = getMetricElements(modelId);
|
||||
List<SchemaElement> metrics = getMetricElements(modelIds);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
.map(schemaElement -> {
|
||||
@@ -75,18 +119,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(sql, metricToAggregate);
|
||||
semanticCorrectInfo.setSql(aggregateSql);
|
||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(Long modelId) {
|
||||
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
return semanticSchema.getMetrics(modelId);
|
||||
return semanticSchema.getMetrics(modelIds);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getDimensionElements(Long modelId) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
return semanticSchema.getDimensions(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FromCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String modelName = semanticParseInfo.getModel().getName();
|
||||
SqlParserReplaceHelper.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(sql)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression);
|
||||
semanticCorrectInfo.setSql(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
|
||||
replaceAlias(semanticCorrectInfo);
|
||||
|
||||
updateFieldNameByLinkingValue(semanticCorrectInfo);
|
||||
|
||||
updateFieldValueByLinkingValue(semanticCorrectInfo);
|
||||
|
||||
correctFieldName(semanticCorrectInfo);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(semanticCorrectInfo.getSql());
|
||||
semanticCorrectInfo.setSql(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
|
||||
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
|
||||
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldValue,
|
||||
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(),
|
||||
fieldValueToFieldNames);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
|
||||
private List<ElementValue> getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
||||
if (Objects.isNull(context)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
||||
return null;
|
||||
}
|
||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
||||
return llmReq.getLinking();
|
||||
}
|
||||
|
||||
|
||||
private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldName,
|
||||
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
|
||||
oldValue -> oldValue,
|
||||
newValue -> newValue,
|
||||
(existingValue, newValue) -> newValue)
|
||||
)));
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
}
|
||||
}
|
||||
@@ -1,46 +1,71 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "group by" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
addGroupByFields(semanticParseInfo);
|
||||
|
||||
addGroupByFields(semanticCorrectInfo);
|
||||
|
||||
addAggregate(semanticCorrectInfo);
|
||||
}
|
||||
|
||||
private void addGroupByFields(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
//add dimension group by
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
dimensions.add(DateUtils.DATE_FIELD);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
||||
//add alias field name
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
|
||||
// if only date in select not add group by.
|
||||
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||
return;
|
||||
}
|
||||
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||
return;
|
||||
}
|
||||
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(field -> {
|
||||
@@ -50,14 +75,17 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
|
||||
addAggregate(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
|
||||
private void addAggregate(SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
addAggregateToMetric(semanticCorrectInfo);
|
||||
addAggregateToMetric(semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,37 +1,66 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
|
||||
import java.util.Set;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
//add aggregate to all metric
|
||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
||||
addHaving(semanticParseInfo);
|
||||
|
||||
//add having expression filed to select
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
|
||||
}
|
||||
|
||||
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticCorrectInfo.getSql(), metrics);
|
||||
semanticCorrectInfo.setSql(havingSql);
|
||||
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2QL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
correctAggFunction(semanticParseInfo);
|
||||
|
||||
replaceAlias(semanticParseInfo);
|
||||
|
||||
updateFieldNameByLinkingValue(semanticParseInfo);
|
||||
|
||||
updateFieldValueByLinkingValue(semanticParseInfo);
|
||||
|
||||
correctFieldName(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldValue,
|
||||
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
|
||||
Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
|
||||
if (Objects.isNull(context)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
|
||||
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||
return null;
|
||||
}
|
||||
return parseResult.getLinkingValues();
|
||||
}
|
||||
|
||||
|
||||
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||
if (CollectionUtils.isEmpty(linking)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
||||
Collectors.groupingBy(ElementValue::getFieldName,
|
||||
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
|
||||
oldValue -> oldValue,
|
||||
newValue -> newValue,
|
||||
(existingValue, newValue) -> newValue)
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
}
|
||||
@@ -1,26 +1,29 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Select" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
super.correct(semanticCorrectInfo);
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& !CollectionUtils.isEmpty(selectFields)
|
||||
&& aggregateFields.size() == selectFields.size()) {
|
||||
return;
|
||||
}
|
||||
addFieldsToSelect(semanticCorrectInfo, sql);
|
||||
addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,23 +2,19 @@ package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.S2SQLDateHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -27,56 +23,67 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
super.correct(semanticCorrectInfo);
|
||||
addDateIfNotExist(semanticParseInfo);
|
||||
|
||||
addDateIfNotExist(semanticCorrectInfo);
|
||||
parserDateDiffFunction(semanticParseInfo);
|
||||
|
||||
parserDateDiffFunction(semanticCorrectInfo);
|
||||
addQueryFilter(queryReq, semanticParseInfo);
|
||||
|
||||
addQueryFilter(semanticCorrectInfo);
|
||||
|
||||
updateFieldValueByTechName(semanticCorrectInfo);
|
||||
updateFieldValueByTechName(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
|
||||
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
||||
|
||||
String preSql = semanticCorrectInfo.getSql();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to preSql :{}", queryFilter);
|
||||
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
|
||||
Expression expression = null;
|
||||
try {
|
||||
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
String sql = SqlParserAddHelper.addWhere(preSql, expression);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
sql = SqlParserReplaceHelper.replaceFunction(sql);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
String sql = semanticCorrectInfo.getSql();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
|
||||
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
||||
sql = SqlParserAddHelper.addParenthesisToWhere(sql);
|
||||
sql = SqlParserAddHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
|
||||
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||
}
|
||||
}
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
@@ -93,21 +100,19 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
}
|
||||
|
||||
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
|
||||
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||
.collect(Collectors.toList());
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
||||
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName);
|
||||
semanticCorrectInfo.setSql(sql);
|
||||
return;
|
||||
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
/**
|
||||
* base Mapper
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||
|
||||
try {
|
||||
doMap(queryContext);
|
||||
} catch (Exception e) {
|
||||
log.error("work error", e);
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||
}
|
||||
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = modelElementMatches.get(modelId);
|
||||
}
|
||||
//remove duplication
|
||||
AtomicBoolean needAddNew = new AtomicBoolean(true);
|
||||
schemaElementMatches.removeIf(
|
||||
existElementMatch -> {
|
||||
SchemaElement existElement = existElementMatch.getElement();
|
||||
SchemaElement newElement = newElementMatch.getElement();
|
||||
if (existElement.equals(newElement)) {
|
||||
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
|
||||
return true;
|
||||
} else {
|
||||
needAddNew.set(false);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
);
|
||||
if (needAddNew.get()) {
|
||||
schemaElementMatches.add(newElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||
if (Objects.isNull(modelSchema)) {
|
||||
return null;
|
||||
}
|
||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
return null;
|
||||
}
|
||||
BeanUtils.copyProperties(elementDb, element);
|
||||
element.setAlias(getAlias(elementDb));
|
||||
return element;
|
||||
}
|
||||
|
||||
public List<String> getAlias(SchemaElement element) {
|
||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||
return element.getAlias();
|
||||
}
|
||||
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||
element.getName())) {
|
||||
return element.getAlias().stream()
|
||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return element.getAlias();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* Base Match Strategy
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
|
||||
List<T> detects = detect(queryContext, terms, detectModelIds);
|
||||
Map<MatchText, List<T>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
|
||||
Set<String> detectSegments = new HashSet<>();
|
||||
|
||||
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||
|
||||
for (Integer index = startIndex; index <= text.length(); ) {
|
||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
||||
Set<String> detectSegments) {
|
||||
return;
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
||||
return terms.stream().sorted(Comparator.comparing(Term::length))
|
||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||
if (CollectionUtils.isEmpty(oneRoundResults)) {
|
||||
return;
|
||||
}
|
||||
for (T oneRoundResult : oneRoundResults) {
|
||||
if (existResults.contains(oneRoundResult)) {
|
||||
boolean isDeleted = existResults.removeIf(
|
||||
existResult -> {
|
||||
boolean delete = needDelete(oneRoundResult, existResult);
|
||||
if (delete) {
|
||||
log.info("deleted existResult:{}", existResult);
|
||||
}
|
||||
return delete;
|
||||
}
|
||||
);
|
||||
if (isDeleted) {
|
||||
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
|
||||
existResults.add(oneRoundResult);
|
||||
}
|
||||
} else {
|
||||
existResults.add(oneRoundResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
}
|
||||
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
.map(entry -> entry.getValue()).findFirst();
|
||||
|
||||
if (first.isPresent()) {
|
||||
matches = first.get();
|
||||
}
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return detectModelIds.contains(modelId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
|
||||
public void logTerms(List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return;
|
||||
}
|
||||
for (Term term : terms) {
|
||||
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
}
|
||||
}
|
||||
|
||||
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/***
|
||||
* A mapper that is capable of semantic understanding of text.
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
//1. query from embedding by queryText
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matchResults);
|
||||
|
||||
//2. build SchemaElementMatch by info
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||
SchemaElement.class);
|
||||
|
||||
if (StringUtils.isBlank(matchResult.getMetadata().get("modelId"))) {
|
||||
continue;
|
||||
}
|
||||
long modelId = Long.parseLong(matchResult.getMetadata().get("modelId"));
|
||||
|
||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(matchResult.getName())
|
||||
.similarity(1 - matchResult.getDistance())
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
//3. add to mapInfo
|
||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* match strategy implement
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private EmbeddingUtils embeddingUtils;
|
||||
|
||||
@Override
|
||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDistance() > oneRoundResult.getDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(EmbeddingResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
Set<String> detectSegments) {
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
|
||||
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
List<String> queryTextsSub) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
Map<String, String> filterCondition = null;
|
||||
// step1. build query params
|
||||
// if only one modelId, add to filterCondition
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||
filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
||||
}
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(queryTextsSub)
|
||||
.filterCondition(filterCondition)
|
||||
.queryEmbeddings(null)
|
||||
.build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. build EmbeddingResults. filter by modelId
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
String modelIdStr = retrieval.getMetadata().get("modelId");
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
return true;
|
||||
}
|
||||
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||
});
|
||||
}
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||
.map(retrieval -> {
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
return embeddingResult;
|
||||
}))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// step4. select mapResul in one round
|
||||
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
|
||||
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||
.limit(roundNumber)
|
||||
.collect(Collectors.toList());
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -1,27 +1,29 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
/**
|
||||
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||
*/
|
||||
@Slf4j
|
||||
public class EntityMapper implements SchemaMapper {
|
||||
public class EntityMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
||||
@@ -32,8 +34,9 @@ public class EntityMapper implements SchemaMapper {
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
@@ -51,7 +54,7 @@ public class EntityMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
||||
List<SchemaElementMatch> schemaElementMatchList) {
|
||||
List<SchemaElementMatch> schemaElementMatchList) {
|
||||
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -1,179 +1,67 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/***
|
||||
* A mapper capable of fuzzy parsing of metric names and dimension names.
|
||||
*/
|
||||
@Slf4j
|
||||
public class FuzzyNameMapper implements SchemaMapper {
|
||||
public class FuzzyNameMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
|
||||
log.debug("before db mapper,mapInfo:{}", queryContext.getMapInfo());
|
||||
public void doMap(QueryContext queryContext) {
|
||||
|
||||
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
detectAndAddToSchema(queryContext, terms, semanticSchema.getDimensions(), SchemaElementType.DIMENSION);
|
||||
|
||||
detectAndAddToSchema(queryContext, terms, semanticSchema.getMetrics(), SchemaElementType.METRIC);
|
||||
|
||||
log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
|
||||
}
|
||||
|
||||
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
|
||||
SchemaElementType schemaElementType) {
|
||||
try {
|
||||
|
||||
Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);
|
||||
|
||||
addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("detectAndAddToSchema error", e);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
|
||||
List<SchemaElement> models) {
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
FuzzyNameMatchStrategy fuzzyNameMatchStrategy = ContextUtils.getBean(FuzzyNameMatchStrategy.class);
|
||||
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
|
||||
List<FuzzyResult> matches = fuzzyNameMatchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);
|
||||
|
||||
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
|
||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||
|
||||
Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
|
||||
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
|
||||
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
|
||||
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
|
||||
if (endIndex > queryText.length()) {
|
||||
continue;
|
||||
}
|
||||
String detectSegment = queryText.substring(startIndex, endIndex);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!name.contains(detectSegment)
|
||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||
continue;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
|
||||
if (Objects.nonNull(preSchemaElements)) {
|
||||
preSchemaElements.addAll(schemaElements);
|
||||
}
|
||||
}
|
||||
for (FuzzyResult match : matches) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(10000L)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||
}
|
||||
return modelResultSet;
|
||||
}
|
||||
|
||||
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
|
||||
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
|
||||
.getModelElementMatches();
|
||||
boolean existElement = modelElementMatches.entrySet().stream()
|
||||
.anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||
|
||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||
: metricDimensionMinThresholdConfig;
|
||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||
modelElementMatches, metricDimensionThresholdConfig);
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
return models.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getName, a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
}, (k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
}
|
||||
|
||||
private void addToSchemaMapInfo(Map<String, Set<SchemaElement>> mapResultRowSet, SchemaMapInfo schemaMap,
|
||||
SchemaElementType schemaElementType) {
|
||||
if (Objects.isNull(mapResultRowSet) || mapResultRowSet.size() <= 0) {
|
||||
return;
|
||||
}
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
|
||||
for (Map.Entry<String, Set<SchemaElement>> entry : mapResultRowSet.entrySet()) {
|
||||
String detectWord = entry.getKey();
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
elements = new ArrayList<>();
|
||||
schemaMap.setMatchedElements(schemaElement.getModel(), elements);
|
||||
}
|
||||
Set<Long> regElementSet = elements.stream()
|
||||
.filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(detectWord)
|
||||
.frequency(10000L)
|
||||
.similarity(mapperHelper.getSimilarity(detectWord, schemaElement.getName()))
|
||||
.build();
|
||||
log.info("schemaElementType:{},add to schema, elementMatch {}", schemaElementType, schemaElementMatch);
|
||||
elements.add(schemaElementMatch);
|
||||
}
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return elements.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Fuzzy Name Match Strategy
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<FuzzyResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
this.allElements = getSchemaElements();
|
||||
return super.match(queryContext, terms, detectModelIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(FuzzyResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<FuzzyResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
if (!name.contains(detectSegment)
|
||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
FuzzyResult fuzzyResult = new FuzzyResult();
|
||||
fuzzyResult.setDetectWord(detectSegment);
|
||||
fuzzyResult.setName(schemaElement.getName());
|
||||
fuzzyResult.setSchemaElement(schemaElement);
|
||||
existResults.add(fuzzyResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<SchemaElement> getSchemaElements() {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
|
||||
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
|
||||
return allElements;
|
||||
}
|
||||
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||
|
||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||
: metricDimensionMinThresholdConfig;
|
||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||
modelElementMatches, metricDimensionThresholdConfig);
|
||||
}
|
||||
return metricDimensionThresholdConfig;
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
return models.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getName, a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
}, (k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -1,83 +1,48 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
/***
|
||||
* A mapper capable of prefix and suffix similarity parsing for
|
||||
* domain names, dimension values, metric names, and dimension names.
|
||||
*/
|
||||
@Slf4j
|
||||
public class HanlpDictMapper implements SchemaMapper {
|
||||
public class HanlpDictMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||
HanlpDictMatchStrategy matchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
|
||||
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
|
||||
detectModelIds);
|
||||
|
||||
List<MapResult> matches = getMatches(matchResult);
|
||||
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matches);
|
||||
|
||||
log.info("queryContext:{},matches:{}", queryContext, matches);
|
||||
|
||||
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
|
||||
}
|
||||
|
||||
private List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||
for (Term term : terms) {
|
||||
log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return detectModelIds.contains(modelId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
for (Term term : terms) {
|
||||
log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
|
||||
term.getFrequency());
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
|
||||
|
||||
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
private void convertTermsToSchemaMapInfo(List<HanlpMapResult> hanlpMapResults, SchemaMapInfo schemaMap,
|
||||
List<Term> terms) {
|
||||
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -85,8 +50,8 @@ public class HanlpDictMapper implements SchemaMapper {
|
||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||
|
||||
for (MapResult mapResult : mapResults) {
|
||||
for (String nature : mapResult.getNatures()) {
|
||||
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
if (Objects.isNull(modelId)) {
|
||||
continue;
|
||||
@@ -95,68 +60,27 @@ public class HanlpDictMapper implements SchemaMapper {
|
||||
if (Objects.isNull(elementType)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||
|
||||
BaseWordBuilder baseWordBuilder = WordBuilderFactory.get(DictWordType.getNatureType(nature));
|
||||
Long elementID = baseWordBuilder.getElementID(nature);
|
||||
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
|
||||
|
||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement element = new SchemaElement();
|
||||
BeanUtils.copyProperties(elementDb, element);
|
||||
element.setAlias(getAlias(elementDb));
|
||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||
element.setName(mapResult.getName());
|
||||
element.setName(hanlpMapResult.getName());
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.word(mapResult.getName())
|
||||
.similarity(mapResult.getSimilarity())
|
||||
.detectWord(mapResult.getDetectWord())
|
||||
.word(hanlpMapResult.getName())
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId,
|
||||
new ArrayList<>());
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = modelElementMatches.get(modelId);
|
||||
}
|
||||
schemaElementMatches.add(schemaElementMatch);
|
||||
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<MapResult> getMatches(Map<MatchText, List<MapResult>> matchResult) {
|
||||
List<MapResult> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
}
|
||||
Optional<List<MapResult>> first = matchResult.entrySet().stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
.map(entry -> entry.getValue()).findFirst();
|
||||
|
||||
if (first.isPresent()) {
|
||||
matches = first.get();
|
||||
}
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<String> getAlias(SchemaElement element) {
|
||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||
return element.getAlias();
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
|
||||
return element.getAlias().stream()
|
||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return element.getAlias();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* match strategy implement
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
Integer agentId = queryReq.getAgentId();
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||
agentId,
|
||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, agentId, detectModelIds).stream()
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. merge pre/suffix result
|
||||
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step4. filter by similarity
|
||||
hanlpMapResults = hanlpMapResults.stream()
|
||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
|
||||
|
||||
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||
return parseResult;
|
||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step5. take only one dimension or 10 metric/dimension value per rond.
|
||||
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
|
||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||
.collect(Collectors.toList())
|
||||
.stream()
|
||||
.limit(1)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
|
||||
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||
oneRoundResults = dimensionMetrics;
|
||||
}
|
||||
// step6. select mapResul in one round
|
||||
selectResultInOneRound(existResults, oneRoundResults);
|
||||
}
|
||||
|
||||
public String getMapKey(HanlpMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -39,10 +41,14 @@ public class MapperHelper {
|
||||
return index;
|
||||
}
|
||||
|
||||
public Integer getStepOffset(List<Integer> termList, Integer index) {
|
||||
|
||||
public Integer getStepOffset(List<Term> termList, Integer index) {
|
||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
|
||||
for (int j = 0; j < termList.size() - 1; j++) {
|
||||
if (termList.get(j) <= index && termList.get(j + 1) > index) {
|
||||
return termList.get(j);
|
||||
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
|
||||
return offsetList.get(j);
|
||||
}
|
||||
}
|
||||
return index;
|
||||
@@ -88,7 +94,7 @@ public class MapperHelper {
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
|
||||
Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
|
||||
Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
|
||||
//contains all
|
||||
if (agentService.containsAllModel(detectModelIds)) {
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -10,8 +9,8 @@ import java.util.Set;
|
||||
/**
|
||||
* match strategy
|
||||
*/
|
||||
public interface MatchStrategy {
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ModelClusterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
|
||||
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
|
||||
for (ModelCluster modelCluster : modelClusters) {
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
if (modelCluster.getModelIds().contains(modelId)) {
|
||||
modelClusterElementMatches.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
|
||||
.addAll(schemaMapInfo.getMatchedElements(modelId));
|
||||
}
|
||||
}
|
||||
}
|
||||
SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
modelClusterMapInfo.setModelElementMatches(modelClusterElementMatches);
|
||||
queryContext.setModelClusterMapInfo(modelClusterMapInfo);
|
||||
}
|
||||
|
||||
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
|
||||
SemanticSchema semanticSchema) {
|
||||
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
|
||||
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
|
||||
modelCluster.removeIf(model -> !matchedModels.contains(model));
|
||||
}).filter(modelCluster -> modelCluster.size() > 0).map(ModelCluster::build).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,163 +0,0 @@
|
||||
package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.compress.utils.Lists;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* match strategy implement
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class QueryMatchStrategy implements MatchStrategy {
|
||||
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
||||
String text = queryReq.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
|
||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||
|
||||
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
|
||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
|
||||
regOffsetToLength, offsetList, detectModelIds);
|
||||
|
||||
List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
|
||||
Map<MatchText, List<MapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
|
||||
Set<Long> detectModelIds) {
|
||||
String text = queryReq.getQueryText();
|
||||
List<MapResult> results = Lists.newArrayList();
|
||||
|
||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
||||
|
||||
Set<MapResult> mapResultRowSet = new LinkedHashSet();
|
||||
|
||||
for (Integer i = index; i <= text.length(); ) {
|
||||
int offset = mapperHelper.getStepOffset(offsetList, index);
|
||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
||||
if (i <= text.length()) {
|
||||
List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
|
||||
selectMapResultInOneRound(mapResultRowSet, mapResults);
|
||||
}
|
||||
}
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
results.addAll(mapResultRowSet);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
|
||||
for (MapResult mapResult : mapResults) {
|
||||
if (mapResultRowSet.contains(mapResult)) {
|
||||
boolean isDeleted = mapResultRowSet.removeIf(
|
||||
entry -> {
|
||||
boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
|
||||
&& entry.getDetectWord().length() < mapResult.getDetectWord().length();
|
||||
if (deleted) {
|
||||
log.info("deleted entry:{}", entry);
|
||||
}
|
||||
return deleted;
|
||||
}
|
||||
);
|
||||
if (isDeleted) {
|
||||
log.info("deleted, add mapResult:{}", mapResult);
|
||||
mapResultRowSet.add(mapResult);
|
||||
}
|
||||
} else {
|
||||
mapResultRowSet.add(mapResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String getMapKey(MapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
|
||||
private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
|
||||
int offset) {
|
||||
String text = queryReq.getQueryText();
|
||||
Integer agentId = queryReq.getAgentId();
|
||||
String detectSegment = text.substring(index, i);
|
||||
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
|
||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
|
||||
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
mapResults.addAll(suffixMapResults);
|
||||
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
// step3. merge pre/suffix result
|
||||
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step4. filter by similarity
|
||||
mapResults = mapResults.stream()
|
||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
log.info("after isSimilarity parseResults:{}", mapResults);
|
||||
|
||||
mapResults = mapResults.stream().map(parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||
return parseResult;
|
||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step5. take only one dimension or 10 metric/dimension value per rond.
|
||||
List<MapResult> dimensionMetrics = mapResults.stream()
|
||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||
.collect(Collectors.toList())
|
||||
.stream()
|
||||
.limit(1)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||
return dimensionMetrics;
|
||||
} else {
|
||||
return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
||||
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -20,17 +21,16 @@ import org.springframework.stereotype.Service;
|
||||
* match strategy implement
|
||||
*/
|
||||
@Service
|
||||
public class SearchMatchStrategy implements MatchStrategy {
|
||||
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
private static final int SEARCH_SIZE = 3;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
||||
Set<Long> detectModelIds) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = originals.stream()
|
||||
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
|
||||
.collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
|
||||
(value1, value2) -> value2));
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
List<Integer> detectIndexList = Lists.newArrayList();
|
||||
|
||||
@@ -46,19 +46,19 @@ public class SearchMatchStrategy implements MatchStrategy {
|
||||
index++;
|
||||
}
|
||||
}
|
||||
Map<MatchText, List<MapResult>> regTextMap = new ConcurrentHashMap<>();
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
|
||||
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
||||
String regText = text.substring(0, detectIndex);
|
||||
String detectSegment = text.substring(detectIndex);
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
|
||||
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
|
||||
queryReq.getAgentId(), detectModelIds);
|
||||
mapResults.addAll(suffixMapResults);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
mapResults = mapResults.stream().filter(entry -> {
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
List<String> natures = entry.getNatures().stream()
|
||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||
.collect(Collectors.toList());
|
||||
@@ -71,10 +71,27 @@ public class SearchMatchStrategy implements MatchStrategy {
|
||||
.regText(regText)
|
||||
.detectSegment(detectSegment)
|
||||
.build();
|
||||
regTextMap.put(matchText, mapResults);
|
||||
regTextMap.put(matchText, hanlpMapResults);
|
||||
}
|
||||
}
|
||||
);
|
||||
return regTextMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(HanlpMapResult a) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||
Integer startIndex,
|
||||
Integer i, int offset) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
@Slf4j
|
||||
public class HttpLLMInterpreter implements LLMInterpreter {
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||
LLMResp.class);
|
||||
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
long startTime = System.currentTimeMillis();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
try {
|
||||
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||
FunctionResp.class);
|
||||
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
||||
System.currentTimeMillis() - startTime);
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestFunction error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* Unified interpreter for invoking the llm layer.
|
||||
*/
|
||||
public interface LLMInterpreter {
|
||||
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Query type parser, determine if the query is a metric query, an entity query,
|
||||
* or another type of query.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||
User user = queryContext.getRequest().getUser();
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
return QueryType.OTHER;
|
||||
}
|
||||
//1. entity queryType
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
if (CollectionUtils.isNotEmpty(selectFields)) {
|
||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
|
||||
return QueryType.TAG;
|
||||
}
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
if (containMetric) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
}
|
||||
return QueryType.OTHER;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -20,7 +20,7 @@ public class SatisfactionChecker {
|
||||
// check all the parse info in candidate
|
||||
public static boolean check(QueryContext queryContext) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
if (query.getQueryMode().equals(DslQuery.QUERY_MODE)) {
|
||||
if (query.getQueryMode().equals(S2SQLQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class DSLParseResult {
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private DslTool dslTool;
|
||||
}
|
||||
@@ -1,231 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//model count priority
|
||||
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
|
||||
if (Objects.nonNull(modelIdByModelCount)) {
|
||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||
return modelIdByModelCount;
|
||||
}
|
||||
|
||||
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||
if (modelTypeMap.size() == 1) {
|
||||
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||
if (modelQueryModes.containsKey(modelSelect)) {
|
||||
log.info("selectModel with only one Model [{}]", modelSelect);
|
||||
return modelSelect;
|
||||
}
|
||||
} else {
|
||||
|
||||
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxModel != null) {
|
||||
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
||||
return maxModel.getKey();
|
||||
}
|
||||
}
|
||||
return 0L;
|
||||
}
|
||||
|
||||
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
|
||||
if (Objects.nonNull(modelElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||
Long modelId = modelElementMatch.getKey();
|
||||
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
|
||||
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
|
||||
).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(modelMatches)) {
|
||||
Integer count = modelMatches.size();
|
||||
modelIdToModelCount.put(modelId, count);
|
||||
}
|
||||
}
|
||||
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
|
||||
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
|
||||
if (Objects.nonNull(maxModelCount)) {
|
||||
return maxModelCount.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* to check can switch Model if context exit Model
|
||||
*
|
||||
* @return false will use context Model, true will use other Model , maybe include context Model
|
||||
*/
|
||||
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> modelQueryModes, SchemaMapInfo schemaMap,
|
||||
ChatContext chatCtx, QueryReq searchCtx,
|
||||
Long modelId, Set<Long> restrictiveModels) {
|
||||
if (!Objects.nonNull(modelId) || modelId <= 0) {
|
||||
return true;
|
||||
}
|
||||
// except content Model, calculate the number of types for each Model, if numbers<=1 will not switch
|
||||
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||
log.info("isAllowSwitch ModelTypeMap [{}]", modelTypeMap);
|
||||
long otherModelTypeNumBigOneCount = modelTypeMap.entrySet().stream()
|
||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
|
||||
.filter(entry -> entry.getValue().getCount() > 1).count();
|
||||
if (otherModelTypeNumBigOneCount >= 1) {
|
||||
return true;
|
||||
}
|
||||
// if query text only contain time , will not switch
|
||||
if (!CollectionUtils.isEmpty(modelQueryModes.values())) {
|
||||
for (SemanticQuery semanticQuery : modelQueryModes.values()) {
|
||||
if (semanticQuery == null) {
|
||||
continue;
|
||||
}
|
||||
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
|
||||
if (semanticParseInfo == null) {
|
||||
continue;
|
||||
}
|
||||
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord()
|
||||
.equalsIgnoreCase(searchCtx.getQueryText())) {
|
||||
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
|
||||
semanticParseInfo.getDateInfo());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (CollectionUtils.isNotEmpty(restrictiveModels) && !restrictiveModels.contains(modelId)) {
|
||||
return true;
|
||||
}
|
||||
// if context Model not in schemaMap , will switch
|
||||
if (schemaMap.getMatchedElements(modelId) == null || schemaMap.getMatchedElements(modelId).size() <= 0) {
|
||||
log.info("modelId not in schemaMap ");
|
||||
return true;
|
||||
}
|
||||
// other will not switch
|
||||
return false;
|
||||
}
|
||||
|
||||
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, ModelMatchResult> modelCount = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!modelCount.containsKey(entry.getKey())) {
|
||||
modelCount.put(entry.getKey(), new ModelMatchResult());
|
||||
}
|
||||
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
modelMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return modelCount;
|
||||
}
|
||||
|
||||
|
||||
public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
Long modelId = queryContext.getRequest().getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isEmpty(restrictiveModels)) {
|
||||
return modelId;
|
||||
}
|
||||
if (restrictiveModels.contains(modelId)) {
|
||||
return modelId;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
Set<Long> matchedModels = mapInfo.getMatchedModels();
|
||||
if (CollectionUtils.isNotEmpty(restrictiveModels)) {
|
||||
matchedModels = matchedModels.stream()
|
||||
.filter(restrictiveModels::contains)
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
Map<Long, SemanticQuery> modelQueryModes = new HashMap<>();
|
||||
for (Long matchedModel : matchedModels) {
|
||||
modelQueryModes.put(matchedModel, null);
|
||||
}
|
||||
if (modelQueryModes.size() == 1) {
|
||||
return modelQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return resolve(modelQueryModes, queryContext, chatCtx,
|
||||
queryContext.getMapInfo(), restrictiveModels);
|
||||
}
|
||||
|
||||
public Long resolve(Map<Long, SemanticQuery> modelQueryModes, QueryContext queryContext,
|
||||
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
||||
Long selectModel = selectModel(modelQueryModes, queryContext.getRequest(),
|
||||
chatCtx, schemaMap, restrictiveModels);
|
||||
if (selectModel > 0) {
|
||||
log.info("selectModel {} ", selectModel);
|
||||
return selectModel;
|
||||
}
|
||||
// get the max SchemaElementType number
|
||||
return selectModelBySchemaElementCount(modelQueryModes, schemaMap);
|
||||
}
|
||||
|
||||
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
|
||||
ChatContext chatCtx,
|
||||
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
||||
// if QueryContext has modelId and in ModelQueryModes
|
||||
if (modelQueryModes.containsKey(queryContext.getModelId())) {
|
||||
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
|
||||
return queryContext.getModelId();
|
||||
}
|
||||
// if ChatContext has modelId and in ModelQueryModes
|
||||
if (chatCtx.getParseInfo().getModelId() > 0) {
|
||||
Long modelId = chatCtx.getParseInfo().getModelId();
|
||||
if (!isAllowSwitch(modelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
|
||||
log.info("selectModel from ChatContext [{}]", modelId);
|
||||
return modelId;
|
||||
}
|
||||
}
|
||||
// default 0
|
||||
return 0L;
|
||||
}
|
||||
}
|
||||
@@ -1,460 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
@Slf4j
|
||||
public class LLMDslParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.info("llm parser url is empty, skip dsl parser, llmParserConfig:{}", llmParserConfig);
|
||||
return;
|
||||
}
|
||||
if (SatisfactionChecker.check(queryCtx)) {
|
||||
log.info("skip dsl parser, queryText:{}", request.getQueryText());
|
||||
return;
|
||||
}
|
||||
try {
|
||||
Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
|
||||
if (Objects.isNull(modelId) || modelId <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
DslTool dslTool = getDslTool(request, modelId);
|
||||
if (Objects.isNull(dslTool)) {
|
||||
log.info("no dsl tool in this agent, skip dsl parser");
|
||||
return;
|
||||
}
|
||||
|
||||
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
|
||||
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
DSLParseResult dslParseResult = DSLParseResult.builder().request(request)
|
||||
.dslTool(dslTool).llmReq(llmReq).llmResp(llmResp).build();
|
||||
|
||||
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
|
||||
|
||||
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
|
||||
|
||||
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
|
||||
|
||||
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("LLMDSLParser error", e);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
||||
&& allFields.contains(schemaElement.getName())
|
||||
).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private List<String> getFieldsExceptDate(List<String> allFields) {
|
||||
if (CollectionUtils.isEmpty(allFields)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return allFields.stream()
|
||||
.filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
|
||||
|
||||
String correctorSql = semanticCorrectInfo.getSql();
|
||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
||||
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!CollectionUtils.isEmpty(expressions)) {
|
||||
DateConf dateInfo = getDateInfo(expressions);
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("set dateInfo error :", e);
|
||||
}
|
||||
|
||||
//set filter
|
||||
try {
|
||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
|
||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
||||
parseInfo.getDimensionFilters().addAll(result);
|
||||
} catch (Exception e) {
|
||||
log.error("set dimensionFilter error :", e);
|
||||
}
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql()));
|
||||
|
||||
Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
|
||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
||||
parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions()));
|
||||
} else {
|
||||
parseInfo.setNativeQuery(true);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql());
|
||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
||||
parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions()));
|
||||
}
|
||||
}
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FilterExpression expression : filterExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
dimensionFilter.setValue(expression.getFieldValue());
|
||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
continue;
|
||||
}
|
||||
dimensionFilter.setName(schemaElement.getName());
|
||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
||||
dimensionFilter.setElementID(schemaElement.getId());
|
||||
|
||||
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
||||
dimensionFilter.setOperator(operatorEnum);
|
||||
dimensionFilter.setFunction(expression.getFunction());
|
||||
result.add(dimensionFilter);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||
.filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return new DateConf();
|
||||
}
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
FilterExpression firstExpression = dateExpressions.get(0);
|
||||
|
||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
return dateInfo;
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
||||
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
||||
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
||||
if (hasSecondDate(dateExpressions)) {
|
||||
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
||||
}
|
||||
}
|
||||
return dateInfo;
|
||||
}
|
||||
|
||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||
}
|
||||
|
||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
|
||||
|
||||
SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
|
||||
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
|
||||
.parseInfo(parseInfo).build();
|
||||
|
||||
List<SemanticCorrector> dslCorrections = ComponentFactory.getSqlCorrections();
|
||||
|
||||
dslCorrections.forEach(dslCorrection -> {
|
||||
try {
|
||||
dslCorrection.correct(correctInfo);
|
||||
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(), correctInfo.getSql());
|
||||
} catch (Exception e) {
|
||||
log.error(String.format("correct error,correctInfo:%s", correctInfo), e);
|
||||
}
|
||||
});
|
||||
return correctInfo;
|
||||
}
|
||||
|
||||
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
|
||||
DSLParseResult dslParseResult) {
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DslQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, dslParseResult);
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", dslTool.getName());
|
||||
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setLlmParseSql(dslParseResult.getLlmResp().getSqlOutput());
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setModel(modelId);
|
||||
model.setId(modelId);
|
||||
model.setName(modelIdToName.get(modelId));
|
||||
parseInfo.setModel(model);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
private DslTool getDslTool(QueryReq request, Long modelId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
List<DslTool> dslTools = agentService.getDslTools(request.getAgentId(), AgentToolType.DSL);
|
||||
Optional<DslTool> dslToolOptional = dslTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
|
||||
return true;
|
||||
}
|
||||
return modelIds.contains(modelId);
|
||||
})
|
||||
.findFirst();
|
||||
return dslToolOptional.orElse(null);
|
||||
}
|
||||
|
||||
private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Set<Long> distinctModelIds = agentService.getDslToolsModelIds(agentId, AgentToolType.DSL);
|
||||
if (agentService.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
|
||||
return modelId;
|
||||
}
|
||||
|
||||
private LLMResp requestLLM(LLMReq llmReq, Long modelId, LLMParserConfig llmParserConfig) {
|
||||
String questUrl = llmParserConfig.getUrl() + llmParserConfig.getQueryToSqlPath();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
try {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity,
|
||||
LLMResp.class);
|
||||
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody());
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
String queryText = queryCtx.getRequest().getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
|
||||
|
||||
fieldNameList.add(DateUtils.DATE_FIELD);
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = DSLDateHelper.getReferenceDate(modelId);
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Set<ElementValue> valueMatches = matchedElements
|
||||
.stream()
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(elementMatch -> {
|
||||
ElementValue elementValue = new ElementValue();
|
||||
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
|
||||
elementValue.setFieldValue(elementMatch.getWord());
|
||||
return elementValue;
|
||||
}).collect(Collectors.toSet());
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
|
||||
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
||||
|
||||
List<SchemaElement> allElements = Lists.newArrayList();
|
||||
allElements.addAll(dimensions);
|
||||
allElements.addAll(metrics);
|
||||
return allElements.stream()
|
||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
Set<String> fieldNameList = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType)
|
||||
|| SchemaElementType.DIMENSION.equals(elementType)
|
||||
|| SchemaElementType.VALUE.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
|
||||
if (!SchemaElementType.VALUE.equals(elementType)) {
|
||||
return schemaElementMatch.getWord();
|
||||
}
|
||||
return itemIdToName.get(schemaElementMatch.getElement().getId());
|
||||
})
|
||||
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
|
||||
.collect(Collectors.toSet());
|
||||
return fieldNameList;
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
results.addAll(metrics);
|
||||
return results;
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
||||
return semanticSchema.getDimensions(modelId).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -5,34 +5,32 @@ import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.HashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@@ -71,7 +69,7 @@ public class MetricInterpretParser implements SemanticParser {
|
||||
|
||||
private void buildQuery(Long modelId, QueryContext queryContext,
|
||||
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
LLMSemanticQuery metricInterpretQuery = QueryManager.createLLMQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
||||
metrics, schemaElementMatches, toolName);
|
||||
@@ -82,9 +80,8 @@ public class MetricInterpretParser implements SemanticParser {
|
||||
}
|
||||
|
||||
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
|
||||
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true);
|
||||
Set<SchemaElement> metrics = modelSchema.getMetrics();
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
List<SchemaElement> metrics = semanticService.getSemanticSchema().getMetrics();
|
||||
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
@@ -113,16 +110,13 @@ public class MetricInterpretParser implements SemanticParser {
|
||||
|
||||
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
|
||||
List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setModel(modelId);
|
||||
model.setId(modelId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setMetrics(metrics);
|
||||
SchemaElement dimension = new SchemaElement();
|
||||
dimension.setBizName(TimeDimensionEnum.DAY.getName());
|
||||
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(model);
|
||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||
semanticParseInfo.setScore(queryReq.getQueryText().length());
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
protected static String selectModelBySchemaElementMatchScore(Map<String, SemanticQuery> modelQueryModes,
|
||||
SchemaModelClusterMapInfo schemaMap) {
|
||||
//model count priority
|
||||
String modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
||||
if (Objects.nonNull(modelIdByModelCount)) {
|
||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||
return modelIdByModelCount;
|
||||
}
|
||||
|
||||
Map<String, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||
if (modelTypeMap.size() == 1) {
|
||||
String modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||
if (modelQueryModes.containsKey(modelSelect)) {
|
||||
log.info("selectModel with only one Model [{}]", modelSelect);
|
||||
return modelSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<String, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxModel != null) {
|
||||
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
||||
return maxModel.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static String getModelIdByMatchModelScore(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<String, Double> modelIdToModelScore = new HashMap<>();
|
||||
if (Objects.nonNull(modelElementMatches)) {
|
||||
for (Entry<String, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||
String modelId = modelElementMatch.getKey();
|
||||
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
||||
// get sum of model match score
|
||||
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
modelIdToModelScore.put(modelId, score);
|
||||
}
|
||||
}
|
||||
Entry<String, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
||||
if (Objects.nonNull(maxModelScore)) {
|
||||
return maxModelScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, ModelMatchResult> modelCount = new HashMap<>();
|
||||
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!modelCount.containsKey(entry.getKey())) {
|
||||
modelCount.put(entry.getKey(), new ModelMatchResult());
|
||||
}
|
||||
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
modelMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return modelCount;
|
||||
}
|
||||
|
||||
|
||||
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
||||
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
||||
Long modelId = queryContext.getRequest().getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
|
||||
return getModelClusterByModelId(modelId, matchedModelClusters);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
Map<String, SemanticQuery> modelQueryModes = new HashMap<>();
|
||||
for (String matchedModel : matchedModelClusters) {
|
||||
modelQueryModes.put(matchedModel, null);
|
||||
}
|
||||
if (modelQueryModes.size() == 1) {
|
||||
return modelQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectModelBySchemaElementMatchScore(modelQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
private String getModelClusterByModelId(Long modelId, Set<String> modelClusterKeySet) {
|
||||
for (String modelClusterKey : modelClusterKeySet) {
|
||||
if (ModelCluster.build(modelClusterKey).getModelIds().contains(modelId)) {
|
||||
return modelClusterKey;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
|
||||
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
|
||||
public boolean check(QueryContext queryCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2SQLParser.class, llmParserConfig);
|
||||
return true;
|
||||
}
|
||||
if (SatisfactionChecker.check(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMS2SQLParser.class, request.getQueryText());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
||||
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
|
||||
if (agentService.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
String modelCluster = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||
log.info("resolve modelId:{},llmParser Models:{}", modelCluster, distinctModelIds);
|
||||
return ModelCluster.build(modelCluster);
|
||||
}
|
||||
|
||||
public CommonAgentTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
|
||||
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
|
||||
AgentToolType.LLM_S2SQL);
|
||||
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
|
||||
return true;
|
||||
}
|
||||
for (Long modelId : modelIdSet) {
|
||||
if (modelIds.contains(modelId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.findFirst();
|
||||
return llmParserTool.orElse(null);
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
String queryText = queryCtx.getRequest().getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
Long firstModelId = modelCluster.getFirstModel();
|
||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setModelName(modelIdToName.get(firstModelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(firstModelId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelCluster, llmParserConfig);
|
||||
|
||||
String priorExts = getPriorExts(modelCluster.getModelIds(), fieldNameList);
|
||||
llmReq.setPriorExts(priorExts);
|
||||
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
if (optimizationConfig.isUseLinkingValueSwitch()) {
|
||||
linking.addAll(linkingValues);
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SQLDateHelper.getReferenceDate(firstModelId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
|
||||
return llmInterpreter.query2sql(llmReq, modelClusterKey);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
private String getPriorExts(Set<Long> modelIds, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
List<ModelSchemaResp> modelSchemaResps = semanticInterpreter.fetchModelSchema(
|
||||
new ArrayList<>(modelIds), true);
|
||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||
|
||||
ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = modelSchemaResp.getMetrics()
|
||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||
.flatMap(metricSchemaResp -> {
|
||||
Set<Pair<String, String>> result = new HashSet<>();
|
||||
String dataFormatType = metricSchemaResp.getDataFormatType();
|
||||
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
|
||||
List<String> aliasList = SchemaItem.getAliasList(metricSchemaResp.getAlias());
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, dataFormatType));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(a -> a.getLeft(), a -> a.getRight(), (k1, k2) -> k1));
|
||||
|
||||
for (String fieldName : fieldNameList) {
|
||||
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
|
||||
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
||||
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
||||
String format = String.format("%s的计量单位是%s", fieldName, "小数; ");
|
||||
extraInfoSb.append(format);
|
||||
}
|
||||
}
|
||||
}
|
||||
return extraInfoSb.toString();
|
||||
}
|
||||
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Set<ElementValue> valueMatches = matchedElements
|
||||
.stream()
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(elementMatch -> {
|
||||
ElementValue elementValue = new ElementValue();
|
||||
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
|
||||
elementValue.setFieldValue(elementMatch.getWord());
|
||||
return elementValue;
|
||||
}).collect(Collectors.toSet());
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(ModelCluster modelCluster) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
|
||||
private Set<String> getTopNFieldNames(ModelCluster modelCluster, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelCluster.getModelIds()).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
results.addAll(metrics);
|
||||
return results;
|
||||
}
|
||||
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
Set<String> fieldNameList = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType)
|
||||
|| SchemaElementType.DIMENSION.equals(elementType)
|
||||
|| SchemaElementType.VALUE.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType elementType = element.getType();
|
||||
if (SchemaElementType.VALUE.equals(elementType)) {
|
||||
return itemIdToName.get(element.getId());
|
||||
}
|
||||
return schemaElementMatch.getWord();
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
return fieldNameList;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMResponseService {
|
||||
|
||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(parseInfo.getModelClusterKey()));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", commonAgentTool.getName());
|
||||
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
|
||||
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
|
||||
return llmResp.getSqlWeight();
|
||||
}
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class LLMS2SQLParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
if (requestService.check(queryCtx)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
//2.get modelId from queryCtx and chatCtx.
|
||||
ModelCluster modelCluster = requestService.getModelCluster(queryCtx, chatCtx, request.getAgentId());
|
||||
if (StringUtils.isBlank(modelCluster.getKey())) {
|
||||
return;
|
||||
}
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelCluster);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, semanticSchema, modelCluster, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, modelCluster.getKey());
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
//5. deduplicate the SQL result list and build parserInfo
|
||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.request(request)
|
||||
.modelCluster(modelCluster)
|
||||
.commonAgentTool(commonAgentTool)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
.linkingValues(linkingValues)
|
||||
.build();
|
||||
|
||||
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||
} else {
|
||||
deduplicationSqlWeight.forEach((sql, weight) -> {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
||||
});
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("parse", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
@@ -7,6 +7,6 @@ import java.util.Set;
|
||||
|
||||
public interface ModelResolver {
|
||||
|
||||
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||
String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ParseResult {
|
||||
|
||||
private ModelCluster modelCluster;
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private CommonAgentTool commonAgentTool;
|
||||
|
||||
private List<ElementValue> linkingValues;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
@@ -10,39 +10,35 @@ import com.tencent.supersonic.common.util.DateUtils;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
public class DSLDateHelper {
|
||||
public class S2SQLDateHelper {
|
||||
|
||||
public static String getReferenceDate(Long modelId) {
|
||||
String chatDetailDate = getChatDetailDate(modelId);
|
||||
if (StringUtils.isNotBlank(chatDetailDate)) {
|
||||
return chatDetailDate;
|
||||
}
|
||||
return DateUtils.getBeforeDate(0);
|
||||
}
|
||||
|
||||
private static String getChatDetailDate(Long modelId) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
if (Objects.isNull(modelId)) {
|
||||
return null;
|
||||
return defaultDate;
|
||||
}
|
||||
ChatConfigFilter filter = new ChatConfigFilter();
|
||||
filter.setModelId(modelId);
|
||||
|
||||
List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
||||
if (CollectionUtils.isEmpty(configResps)) {
|
||||
return null;
|
||||
return defaultDate;
|
||||
}
|
||||
ChatConfigResp chatConfigResp = configResps.get(0);
|
||||
if (Objects.isNull(chatConfigResp.getChatDetailConfig()) || Objects.isNull(
|
||||
chatConfigResp.getChatDetailConfig().getChatDefaultConfig())) {
|
||||
return null;
|
||||
return defaultDate;
|
||||
}
|
||||
|
||||
ChatDefaultConfigReq chatDefaultConfig = chatConfigResp.getChatDetailConfig().getChatDefaultConfig();
|
||||
Integer unit = chatDefaultConfig.getUnit();
|
||||
String period = chatDefaultConfig.getPeriod();
|
||||
if (Objects.nonNull(unit)) {
|
||||
// If the unit is set to less than 0, then do not add relative date.
|
||||
if (unit < 0) {
|
||||
return null;
|
||||
}
|
||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
|
||||
if (Objects.isNull(datePeriodEnum)) {
|
||||
return DateUtils.getBeforeDate(unit);
|
||||
@@ -50,6 +46,7 @@ public class DSLDateHelper {
|
||||
return DateUtils.getBeforeDate(unit, datePeriodEnum);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
return defaultDate;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.time;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ChatGptHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class LLMTimeEnhancementParse implements SemanticParser {
|
||||
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
|
||||
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
|
||||
try {
|
||||
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
|
||||
if (!queryContext.getCandidateQueries().isEmpty()) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
DateConf dateInfo = query.getParseInfo().getDateInfo();
|
||||
JSONObject jsonObject = JSON.parseObject(inferredTime);
|
||||
if (jsonObject.containsKey("date")) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(jsonObject.getString("date"));
|
||||
dateInfo.setEndDate(jsonObject.getString("date"));
|
||||
query.getParseInfo().setDateInfo(dateInfo);
|
||||
} else if (jsonObject.containsKey("start")) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(jsonObject.getString("start"));
|
||||
dateInfo.setEndDate(jsonObject.getString("end"));
|
||||
query.getParseInfo().setDateInfo(dateInfo);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception exception) {
|
||||
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
|
||||
(Object) exception.getStackTrace());
|
||||
}
|
||||
|
||||
log.info("{} after queryContext:{},chatContext:{}",
|
||||
LLMTimeEnhancementParse.class.getSimpleName(), queryContext, chatContext);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -3,14 +3,14 @@ package com.tencent.supersonic.chat.parser.plugin;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
@@ -18,8 +18,10 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -29,6 +31,12 @@ public abstract class PluginParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!checkPreCondition(queryContext)) {
|
||||
return;
|
||||
}
|
||||
@@ -51,8 +59,10 @@ public abstract class PluginParser implements SemanticParser {
|
||||
}
|
||||
for (Long modelId : modelIds) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
|
||||
queryContext.getMapInfo().getMatchedElements(modelId), pluginRecallResult.getDistance());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
||||
queryContext.getRequest(),
|
||||
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
|
||||
pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
pluginQuery.setParseInfo(semanticParseInfo);
|
||||
@@ -72,12 +82,9 @@ public abstract class PluginParser implements SemanticParser {
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setModel(modelId);
|
||||
model.setId(modelId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(model);
|
||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
@@ -22,23 +25,16 @@ import org.springframework.util.CollectionUtils;
|
||||
@Slf4j
|
||||
public class EmbeddingBasedParser extends PluginParser {
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
|
||||
@Override
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl()) && llmInterpreter instanceof HttpLLMInterpreter) {
|
||||
return false;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
if (CollectionUtils.isEmpty(plugins)) {
|
||||
return false;
|
||||
}
|
||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||
for (SemanticQuery semanticQuery : semanticQueries) {
|
||||
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,44 +1,39 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionBasedParser extends PluginParser {
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
|
||||
@Override
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||
String functionUrl = functionCallConfig.getUrl();
|
||||
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryContext)) {
|
||||
if (StringUtils.isBlank(functionUrl) && llmInterpreter instanceof HttpLLMInterpreter) {
|
||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||
queryContext.getRequest().getQueryText());
|
||||
return false;
|
||||
@@ -89,7 +84,7 @@ public class FunctionBasedParser extends PluginParser {
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryContext.getRequest().getQueryText())
|
||||
.pluginConfigs(pluginToFunctionCall).build();
|
||||
functionResp = requestFunction(functionReq);
|
||||
functionResp = llmInterpreter.requestFunction(functionReq);
|
||||
}
|
||||
return functionResp;
|
||||
}
|
||||
@@ -102,7 +97,7 @@ public class FunctionBasedParser extends PluginParser {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
if (DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
if (S2SQLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
return false;
|
||||
}
|
||||
if (plugin.getParseModeConfig() == null) {
|
||||
@@ -132,25 +127,4 @@ public class FunctionBasedParser extends PluginParser {
|
||||
return functionDOList;
|
||||
}
|
||||
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
long startTime = System.currentTimeMillis();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
try {
|
||||
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||
FunctionResp.class);
|
||||
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
||||
System.currentTimeMillis() - startTime);
|
||||
return responseEntity.getBody();
|
||||
} catch (Exception e) {
|
||||
log.error("requestFunction error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,10 +9,14 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -40,13 +44,26 @@ public class AgentCheckParser implements SemanticParser {
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
queries.removeIf(query -> {
|
||||
for (RuleQueryTool tool : queryTools) {
|
||||
if (!tool.getQueryModes().contains(query.getQueryMode())) {
|
||||
if (CollectionUtils.isNotEmpty(tool.getQueryModes())
|
||||
&& !tool.getQueryModes().contains(query.getQueryMode())) {
|
||||
return true;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
|
||||
if (QueryManager.isTagQuery(query.getQueryMode())) {
|
||||
return !tool.getQueryTypes().contains(QueryType.TAG.name());
|
||||
}
|
||||
if (QueryManager.isMetricQuery(query.getQueryMode())) {
|
||||
return !tool.getQueryTypes().contains(QueryType.METRIC.name());
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(tool.getModelIds())) {
|
||||
return true;
|
||||
}
|
||||
if (tool.isContainsAllModel() || tool.getModelIds().contains(query.getParseInfo().getModelId())) {
|
||||
if (tool.isContainsAllModel()) {
|
||||
return false;
|
||||
}
|
||||
if (new HashSet<>(tool.getModelIds())
|
||||
.containsAll(query.getParseInfo().getModel().getModelIds())) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,12 @@ import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* AggregateTypeParser extracts aggregation type specified in the user query
|
||||
* based on keyword matching.
|
||||
* Currently, it supports 7 types of aggregation: max, min, sum, avg, topN,
|
||||
* distinct count, count.
|
||||
*/
|
||||
@Slf4j
|
||||
public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
|
||||
@@ -1,33 +1,47 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.TAG;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||
|
||||
/**
|
||||
* ContextInheritParser tries to inherit certain schema elements from context
|
||||
* so that in multi-turn conversations users don't need to mention some keyword
|
||||
* repeatedly.
|
||||
*/
|
||||
@Slf4j
|
||||
public class ContextInheritParser implements SemanticParser {
|
||||
|
||||
@@ -36,20 +50,22 @@ public class ContextInheritParser implements SemanticParser {
|
||||
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
|
||||
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(TAG, Arrays.asList(TAG)),
|
||||
new AbstractMap.SimpleEntry<>(MODEL, Arrays.asList(MODEL)),
|
||||
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (!shouldInherit(queryContext, chatContext)) {
|
||||
if (!shouldInherit(queryContext)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Long modelId = chatContext.getParseInfo().getModelId();
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo()
|
||||
.getMatchedElements(modelId);
|
||||
|
||||
ModelCluster modelCluster = getMatchedModelCluster(queryContext, chatContext);
|
||||
if (modelCluster == null) {
|
||||
return;
|
||||
}
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
@@ -64,18 +80,18 @@ public class ContextInheritParser implements SemanticParser {
|
||||
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(modelId, queryContext, chatContext);
|
||||
if (existSameQuery(query.getParseInfo().getModelId(), query.getQueryMode(), queryContext)) {
|
||||
query.fillParseInfo(chatContext);
|
||||
if (existSameQuery(query.getParseInfo().getModelClusterKey(), query.getQueryMode(), queryContext)) {
|
||||
continue;
|
||||
}
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSameQuery(Long modelId, String queryMode, QueryContext queryContext) {
|
||||
private boolean existSameQuery(String modelClusterKey, String queryMode, QueryContext queryContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||
&& semanticQuery.getParseInfo().getModelId().equals(modelId)) {
|
||||
&& semanticQuery.getParseInfo().getModelClusterKey().equals(modelClusterKey)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -89,30 +105,41 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return matches.stream().anyMatch(m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricEntityQuery)) {
|
||||
&& !(ruleQuery instanceof MetricTagQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
});
|
||||
}
|
||||
|
||||
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
|
||||
Long contextModelId = chatContext.getParseInfo().getModelId();
|
||||
// if map info doesn't contain the same Model of the context,
|
||||
// no inheritance could be done
|
||||
if (queryContext.getMapInfo().getMatchedElements(contextModelId) == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
protected boolean shouldInherit(QueryContext queryContext) {
|
||||
// if candidates only have MetricModel mode, count in context
|
||||
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
|
||||
.filter(query -> query instanceof MetricModelQuery).collect(
|
||||
Collectors.toList());
|
||||
if (metricModelQueries.size() == queryContext.getCandidateQueries().size()) {
|
||||
return true;
|
||||
} else {
|
||||
return queryContext.getCandidateQueries().size() == 0;
|
||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected ModelCluster getMatchedModelCluster(QueryContext queryContext, ChatContext chatContext) {
|
||||
String contextModelClusterKey = chatContext.getParseInfo().getModelClusterKey();
|
||||
if (StringUtils.isBlank(contextModelClusterKey)) {
|
||||
return null;
|
||||
}
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
List<ModelCluster> allModelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
Set<String> queryModelClusters = queryContext.getModelClusterMapInfo().getMatchedModelClusters();
|
||||
ModelCluster contextModelCluster = ModelCluster.build(contextModelClusterKey);
|
||||
for (String cluster : queryModelClusters) {
|
||||
ModelCluster queryModelCluster = ModelCluster.build(cluster);
|
||||
for (ModelCluster modelCluster : allModelClusters) {
|
||||
if (modelCluster.getModelIds().containsAll(contextModelCluster.getModelIds())
|
||||
&& modelCluster.getModelIds().containsAll(queryModelCluster.getModelIds())) {
|
||||
return queryModelCluster;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,42 +1,34 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* QueryModeParser resolves a specific query mode according to co-appearance
|
||||
* of certain schema element types.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryModeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
// iterate all schemaElementMatches to resolve semantic query
|
||||
for (Long modelId : mapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
|
||||
SchemaModelClusterMapInfo modelClusterMapInfo = queryContext.getModelClusterMapInfo();
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (String modelClusterKey : modelClusterMapInfo.getMatchedModelClusters()) {
|
||||
List<SchemaElementMatch> elementMatches = modelClusterMapInfo.getMatchedElements(modelClusterKey);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(modelId, queryContext, chatContext);
|
||||
query.fillParseInfo(chatContext);
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
// if modelElementMatches id empty,so remove it.
|
||||
Map<Long, List<SchemaElementMatch>> filterModelElementMatches = new HashMap<>();
|
||||
for (Long modelId : queryContext.getMapInfo().getModelElementMatches().keySet()) {
|
||||
if (!CollectionUtils.isEmpty(queryContext.getMapInfo().getModelElementMatches().get(modelId))) {
|
||||
filterModelElementMatches.put(modelId, queryContext.getMapInfo().getModelElementMatches().get(modelId));
|
||||
}
|
||||
}
|
||||
queryContext.getMapInfo().setModelElementMatches(filterModelElementMatches);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* RuleBasedParser acts as a container that incorporates a group of
|
||||
* rule-based semantic parsers.
|
||||
*/
|
||||
@Slf4j
|
||||
public class RuleBasedParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> ruleParsers = Arrays.asList(
|
||||
new QueryModeParser(),
|
||||
new ContextInheritParser(),
|
||||
new AgentCheckParser(),
|
||||
new TimeRangeParser(),
|
||||
new AggregateTypeParser()
|
||||
);
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
ruleParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,13 @@ import com.xkzhangsan.time.nlp.TimeNLPUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
|
||||
/**
|
||||
* TimeRangeParser extracts time range specified in the user query
|
||||
* based on keyword matching.
|
||||
* Currently, it supports two kinds of expression:
|
||||
* 1. Recent unit: 近N天/周/月/年、过去N天/周/月/年
|
||||
* 2. Concrete date: 2023年11月15日、20231115
|
||||
*/
|
||||
@Slf4j
|
||||
public class TimeRangeParser implements SemanticParser {
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ public class ChatConfigDO {
|
||||
|
||||
private Integer status;
|
||||
|
||||
private String llmExamples;
|
||||
|
||||
/**
|
||||
* record info
|
||||
*/
|
||||
|
||||
@@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
public enum CostType {
|
||||
MAPPER(1, "mapper"),
|
||||
PARSER(2, "parser"),
|
||||
QUERY(3, "query");
|
||||
QUERY(3, "query"),
|
||||
PARSERRESPONDER(4, "responder"),
|
||||
POSTPROCESSOR(5, "postprocessor");
|
||||
|
||||
private Integer type;
|
||||
private String name;
|
||||
|
||||
@@ -12,6 +12,10 @@ public interface ChatParseMapper {
|
||||
|
||||
boolean batchSaveParseInfo(@Param("list") List<ChatParseDO> list);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
|
||||
boolean updateParseInfo(ChatParseDO chatParseDO);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||
|
||||
}
|
||||
|
||||
@@ -7,6 +7,6 @@ import java.util.List;
|
||||
@Mapper
|
||||
public interface ShowCaseCustomMapper {
|
||||
|
||||
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId);
|
||||
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId, String userName);
|
||||
|
||||
}
|
||||
|
||||
@@ -21,18 +21,21 @@ public interface ChatQueryRepository {
|
||||
|
||||
void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
|
||||
|
||||
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
|
||||
|
||||
ChatQueryDO getLastChatQuery(long chatId);
|
||||
|
||||
int updateChatQuery(ChatQueryDO chatQueryDO);
|
||||
|
||||
Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq);
|
||||
|
||||
Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses,
|
||||
List<SemanticParseInfo> selectedParses);
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
|
||||
public ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||
|
||||
Boolean deleteChatQuery(Long questionId);
|
||||
}
|
||||
|
||||
@@ -7,5 +7,5 @@ import java.util.List;
|
||||
|
||||
public interface StatisticsRepository {
|
||||
|
||||
boolean batchSaveStatistics(List<StatisticsDO> list);
|
||||
void batchSaveStatistics(List<StatisticsDO> list);
|
||||
}
|
||||
|
||||
@@ -4,33 +4,31 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.PageUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@@ -78,9 +76,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId) {
|
||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoCommend.getCurrent(),
|
||||
pageQueryInfoCommend.getPageSize(), agentId).stream().map(this::convertTo)
|
||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
||||
pageQueryInfoReq.getPageSize(), agentId, pageQueryInfoReq.getUserName())
|
||||
.stream().map(this::convertTo)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -129,30 +128,37 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return queryId;
|
||||
}
|
||||
|
||||
public Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
@Override
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses,
|
||||
List<SemanticParseInfo> selectedParses) {
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
|
||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
|
||||
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
|
||||
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
|
||||
Boolean save = chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||
return save;
|
||||
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
|
||||
chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||
return chatParseDOList;
|
||||
}
|
||||
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
|
||||
@Override
|
||||
public void updateChatParseInfo(List<ChatParseDO> chatParseDOS) {
|
||||
for (ChatParseDO chatParseDO : chatParseDOS) {
|
||||
chatParseMapper.updateParseInfo(chatParseDO);
|
||||
}
|
||||
}
|
||||
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
parses.get(i).setId(base + i + 1);
|
||||
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
||||
chatParseDO.setQuestionId(queryId);
|
||||
chatParseDO.setQueryText(queryReq.getQueryText());
|
||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||
chatParseDO.setIsCandidate(isCandidate);
|
||||
chatParseDO.setParseId(base + i + 1);
|
||||
chatParseDO.setIsCandidate(1);
|
||||
if (i == 0) {
|
||||
chatParseDO.setIsCandidate(0);
|
||||
}
|
||||
chatParseDO.setParseId(parses.get(i).getId());
|
||||
chatParseDO.setCreateTime(new java.util.Date());
|
||||
chatParseDO.setUserName(queryReq.getUser().getName());
|
||||
chatParseDOList.add(chatParseDO);
|
||||
@@ -179,8 +185,14 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
|
||||
}
|
||||
|
||||
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId) {
|
||||
return chatParseMapper.getParseInfo(questionId, userName, parseId);
|
||||
|
||||
public ChatParseDO getParseInfo(Long questionId, int parseId) {
|
||||
return chatParseMapper.getParseInfo(questionId, parseId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> getParseInfoList(List<Long> questionIds) {
|
||||
return chatParseMapper.getParseInfoList(questionIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -20,10 +20,9 @@ public class StatisticsRepositoryImpl implements StatisticsRepository {
|
||||
this.statisticsMapper = statisticsMapper;
|
||||
}
|
||||
|
||||
public boolean batchSaveStatistics(List<StatisticsDO> list) {
|
||||
return statisticsMapper.batchSaveStatistics(list);
|
||||
public void batchSaveStatistics(List<StatisticsDO> list) {
|
||||
statisticsMapper.batchSaveStatistics(list);
|
||||
}
|
||||
|
||||
;
|
||||
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user