mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
Compare commits
176 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d095f9676 | ||
|
|
7c7ccadcfd | ||
|
|
02b9dc6947 | ||
|
|
4222d7e2b5 | ||
|
|
b5daf04c96 | ||
|
|
d26089a249 | ||
|
|
627683d437 | ||
|
|
87e222eecc | ||
|
|
667272b103 | ||
|
|
3f08d95aaa | ||
|
|
27ebda3439 | ||
|
|
52eca178d3 | ||
|
|
f5a064aaad | ||
|
|
41e585324d | ||
|
|
c36082476f | ||
|
|
4bc1378285 | ||
|
|
23d01c4f83 | ||
|
|
49ebb70cb3 | ||
|
|
27bb1b322e | ||
|
|
0534053ff9 | ||
|
|
fe2a424718 | ||
|
|
d79e30cd7a | ||
|
|
aa433baa06 | ||
|
|
30bb9a1dc0 | ||
|
|
c168925f03 | ||
|
|
42c0bea8fc | ||
|
|
291c00749a | ||
|
|
6763ea0f7b | ||
|
|
f917defea8 | ||
|
|
91718592d4 | ||
|
|
6d9a8095eb | ||
|
|
5b8fdbc6fd | ||
|
|
e15e44e4a2 | ||
|
|
980d317152 | ||
|
|
2c23c2f574 | ||
|
|
80ad75503b | ||
|
|
0143b0a1b2 | ||
|
|
dd115f9d37 | ||
|
|
f198ce1ef8 | ||
|
|
18211a215d | ||
|
|
d6a386ad03 | ||
|
|
8f19584ad7 | ||
|
|
d9eaf79ab8 | ||
|
|
05b1a7ec3b | ||
|
|
11cdcb29fa | ||
|
|
46a9e5b097 | ||
|
|
8c65ac80b5 | ||
|
|
5b3a9ffba8 | ||
|
|
8688c8c2b3 | ||
|
|
13d8b9cff5 | ||
|
|
aa448b1ba3 | ||
|
|
7ef3d92f2c | ||
|
|
9f09598ccd | ||
|
|
36c8938ff7 | ||
|
|
3271db4ca6 | ||
|
|
400d8b34fd | ||
|
|
d4374f7074 | ||
|
|
438ee539d6 | ||
|
|
5ccde0206c | ||
|
|
74ed269544 | ||
|
|
1ad2c5402b | ||
|
|
805abeb261 | ||
|
|
551a376b00 | ||
|
|
47be92d5f6 | ||
|
|
5feac0c14e | ||
|
|
0f02e21eaa | ||
|
|
cdb84716b7 | ||
|
|
731238de08 | ||
|
|
cb1ad94086 | ||
|
|
24b0be7566 | ||
|
|
3241ef87a3 | ||
|
|
bd541e1199 | ||
|
|
f998f27c6f | ||
|
|
cf788316c3 | ||
|
|
8ed7e91221 | ||
|
|
e537b738e4 | ||
|
|
bf3a111e55 | ||
|
|
63a526709d | ||
|
|
e0088e8f8f | ||
|
|
7d33c49db8 | ||
|
|
acee0a36da | ||
|
|
a528ba6070 | ||
|
|
ba224ac335 | ||
|
|
18aa14118c | ||
|
|
4e139c837a | ||
|
|
6ad74bb206 | ||
|
|
16c3de44e4 | ||
|
|
608a4f7a2f | ||
|
|
cd972d0850 | ||
|
|
2aeeb1a14e | ||
|
|
41aa6ff8e4 | ||
|
|
67f658ced2 | ||
|
|
94fa86629d | ||
|
|
2cb0640a7b | ||
|
|
772d5bd3ae | ||
|
|
d6681ead60 | ||
|
|
d94fd4714f | ||
|
|
0365886270 | ||
|
|
aa6c658a9a | ||
|
|
6e3f871015 | ||
|
|
6c9983164e | ||
|
|
e00b935c1f | ||
|
|
f5f9c0314a | ||
|
|
910384d17f | ||
|
|
2fe56e7462 | ||
|
|
b8989e204f | ||
|
|
84b7c2c062 | ||
|
|
14373309aa | ||
|
|
9cd3e22721 | ||
|
|
2f812372d7 | ||
|
|
9f813ca1c0 | ||
|
|
70784598e1 | ||
|
|
ad20380283 | ||
|
|
f4e3922f47 | ||
|
|
bfac71a7d0 | ||
|
|
435e789fa4 | ||
|
|
b9f5e0a354 | ||
|
|
372e4acc2c | ||
|
|
8f37c3175f | ||
|
|
438e8463f5 | ||
|
|
ae9aa1ba0f | ||
|
|
688d26c457 | ||
|
|
8b99b46787 | ||
|
|
80cce47f58 | ||
|
|
c92184d89f | ||
|
|
a8fe575999 | ||
|
|
38099c8cc7 | ||
|
|
32e51257f6 | ||
|
|
e44e7ca8d5 | ||
|
|
d533496b2a | ||
|
|
eb9db28352 | ||
|
|
836ee5f3ed | ||
|
|
0fa31f84a3 | ||
|
|
9a3c71df4a | ||
|
|
e4e39e0496 | ||
|
|
cd901fbc68 | ||
|
|
8fde378534 | ||
|
|
d8f81aca65 | ||
|
|
b9895d541b | ||
|
|
4fbc3c8533 | ||
|
|
62e2bf7de6 | ||
|
|
c6f9ea2b20 | ||
|
|
166a3cfe09 | ||
|
|
cbf84876de | ||
|
|
8bd43f113b | ||
|
|
d710986923 | ||
|
|
dd63b78937 | ||
|
|
8d1a07585b | ||
|
|
f4638b48d5 | ||
|
|
a1d56fc7e4 | ||
|
|
9879c99873 | ||
|
|
1016efc646 | ||
|
|
156aa6822b | ||
|
|
34eb94320e | ||
|
|
ba1d14f40a | ||
|
|
dc4fbb1a14 | ||
|
|
7d770d2a6d | ||
|
|
7b861f563c | ||
|
|
65614ed3ba | ||
|
|
7acdf9cb3d | ||
|
|
2e4954a4e8 | ||
|
|
36052cb4f2 | ||
|
|
8d81f63e08 | ||
|
|
bf5be11549 | ||
|
|
36907ccac1 | ||
|
|
0f3e9e8754 | ||
|
|
883cdbefbe | ||
|
|
968d50e071 | ||
|
|
a9bb1c1f68 | ||
|
|
207d6cba43 | ||
|
|
40ba179703 | ||
|
|
5b8fde70ca | ||
|
|
e3232f0198 | ||
|
|
c5536aa25d | ||
|
|
37bb9ff767 | ||
|
|
f2e8207245 |
@@ -27,7 +27,7 @@ move webapp ..\..\launchers\standalone\target\classes
|
|||||||
|
|
||||||
rem 5. build backend python modules
|
rem 5. build backend python modules
|
||||||
echo "start installing python modules with pip: ${pip_path}"
|
echo "start installing python modules with pip: ${pip_path}"
|
||||||
set requirementPath="%baseDir%/../chat/core/src/main/python/requirements.txt"
|
set requirementPath="%baseDir%/../chat/python/requirements.txt"
|
||||||
%pip_path% install -r %requirementPath%
|
%pip_path% install -r %requirementPath%
|
||||||
echo "install python modules success"
|
echo "install python modules success"
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ rm -fr ${buildDir}/webapp
|
|||||||
|
|
||||||
#5. build backend python modules
|
#5. build backend python modules
|
||||||
echo "start installing python modules with pip: ${pip_path}"
|
echo "start installing python modules with pip: ${pip_path}"
|
||||||
requirementPath=$baseDir/../chat/core/src/main/python/requirements.txt
|
requirementPath=$baseDir/../chat/python/requirements.txt
|
||||||
${pip_path} install -r ${requirementPath}
|
${pip_path} install -r ${requirementPath}
|
||||||
echo "install python modules success"
|
echo "install python modules success"
|
||||||
|
|
||||||
|
|||||||
@@ -96,10 +96,11 @@ function runPythonService {
|
|||||||
break
|
break
|
||||||
else
|
else
|
||||||
if [ "$i" -eq 10 ]; then
|
if [ "$i" -eq 10 ]; then
|
||||||
echo "llmparser Health check failed after 10 attempts. Exiting."
|
echo "llmparser Health check failed after 10 attempts."
|
||||||
|
echo "May still downloading model files. Please check llmparser.log in runtime directory."
|
||||||
fi
|
fi
|
||||||
echo "Retrying after 5 seconds..."
|
echo "Retrying after 5 seconds..."
|
||||||
sleep 5
|
sleep 5
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,7 @@
|
|||||||
</includes>
|
</includes>
|
||||||
</fileSet>
|
</fileSet>
|
||||||
<fileSet>
|
<fileSet>
|
||||||
<directory>${project.basedir}/../../chat/core/src/main/python</directory>
|
<directory>${project.basedir}/../../chat/python</directory>
|
||||||
<outputDirectory>llmparser</outputDirectory>
|
<outputDirectory>llmparser</outputDirectory>
|
||||||
<fileMode>0777</fileMode>
|
<fileMode>0777</fileMode>
|
||||||
<directoryMode>0755</directoryMode>
|
<directoryMode>0755</directoryMode>
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class AuthGroup {
|
public class AuthGroup {
|
||||||
|
|
||||||
private String modelId;
|
private Long modelId;
|
||||||
private String name;
|
private String name;
|
||||||
private Integer groupId;
|
private Integer groupId;
|
||||||
private List<AuthRule> authRules;
|
private List<AuthRule> authRules;
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import lombok.ToString;
|
|||||||
@ToString
|
@ToString
|
||||||
public class AuthRes {
|
public class AuthRes {
|
||||||
|
|
||||||
private String modelId;
|
private Long modelId;
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
public AuthRes() {
|
public AuthRes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
public AuthRes(String modelId, String name) {
|
public AuthRes(Long modelId, String name) {
|
||||||
this.modelId = modelId;
|
this.modelId = modelId;
|
||||||
this.name = name;
|
this.name = name;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package com.tencent.supersonic.auth.api.authorization.request;
|
package com.tencent.supersonic.auth.api.authorization.request;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
@@ -15,5 +17,17 @@ public class QueryAuthResReq {
|
|||||||
|
|
||||||
private List<AuthRes> resources;
|
private List<AuthRes> resources;
|
||||||
|
|
||||||
private String modelId;
|
private Long modelId;
|
||||||
|
|
||||||
|
private List<Long> modelIds;
|
||||||
|
|
||||||
|
public List<Long> getModelIds() {
|
||||||
|
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||||
|
return modelIds;
|
||||||
|
}
|
||||||
|
if (modelId != null) {
|
||||||
|
return Lists.newArrayList(modelId);
|
||||||
|
}
|
||||||
|
return Lists.newArrayList();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,12 +33,6 @@
|
|||||||
<artifactId>spring-boot-starter-jdbc</artifactId>
|
<artifactId>spring-boot-starter-jdbc</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.alibaba</groupId>
|
<groupId>com.alibaba</groupId>
|
||||||
<artifactId>druid</artifactId>
|
<artifactId>druid</artifactId>
|
||||||
@@ -52,12 +46,7 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis-spring</artifactId>
|
|
||||||
<version>${mybatis-spring.version}</version>
|
|
||||||
<scope>compile</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.github.pagehelper</groupId>
|
<groupId>com.github.pagehelper</groupId>
|
||||||
<artifactId>pagehelper</artifactId>
|
<artifactId>pagehelper</artifactId>
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
|||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.jdbc.core.JdbcTemplate;
|
import org.springframework.jdbc.core.JdbcTemplate;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -48,7 +47,7 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
|
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
|
||||||
return load().stream()
|
return load().stream()
|
||||||
.filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
|
.filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
|
||||||
&& modelId.equals(group.getModelId()))
|
&& modelId.equals(group.getModelId().toString()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,17 +79,14 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
@Override
|
@Override
|
||||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||||
if (!CollectionUtils.isEmpty(userOrgIds)) {
|
List<AuthGroup> groups = getAuthGroups(req.getModelIds(), user.getName(), new ArrayList<>(userOrgIds));
|
||||||
req.setDepartmentIds(new ArrayList<>(userOrgIds));
|
|
||||||
}
|
|
||||||
List<AuthGroup> groups = getAuthGroups(req, user.getName());
|
|
||||||
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
||||||
Map<String, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
Map<Long, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||||
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
||||||
Map<String, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
Map<Long, List<AuthRes>> reqAuthRes = req.getResources().stream()
|
||||||
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
.collect(Collectors.groupingBy(AuthRes::getModelId));
|
||||||
|
|
||||||
for (String modelId : reqAuthRes.keySet()) {
|
for (Long modelId : reqAuthRes.keySet()) {
|
||||||
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
List<AuthRes> reqResourcesList = reqAuthRes.get(modelId);
|
||||||
AuthResGrp rg = new AuthResGrp();
|
AuthResGrp rg = new AuthResGrp();
|
||||||
if (authGroupsByModelId.containsKey(modelId)) {
|
if (authGroupsByModelId.containsKey(modelId)) {
|
||||||
@@ -113,7 +109,7 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(req.getModelId())) {
|
if (req.getModelId() != null) {
|
||||||
List<AuthGroup> authGroups = authGroupsByModelId.get(req.getModelId());
|
List<AuthGroup> authGroups = authGroupsByModelId.get(req.getModelId());
|
||||||
if (!CollectionUtils.isEmpty(authGroups)) {
|
if (!CollectionUtils.isEmpty(authGroups)) {
|
||||||
for (AuthGroup group : authGroups) {
|
for (AuthGroup group : authGroups) {
|
||||||
@@ -130,17 +126,17 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
return resource;
|
return resource;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<AuthGroup> getAuthGroups(QueryAuthResReq req, String userName) {
|
private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName, List<String> departmentIds) {
|
||||||
List<AuthGroup> groups = load().stream()
|
List<AuthGroup> groups = load().stream()
|
||||||
.filter(group -> {
|
.filter(group -> {
|
||||||
if (!Objects.equals(group.getModelId(), req.getModelId())) {
|
if (CollectionUtils.isEmpty(modelIds) || !modelIds.contains(group.getModelId())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
||||||
.contains(userName)) {
|
.contains(userName)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
for (String departmentId : req.getDepartmentIds()) {
|
for (String departmentId : departmentIds) {
|
||||||
if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
|
if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
|
||||||
&& group.getAuthorizedDepartmentIds().contains(departmentId)) {
|
&& group.getAuthorizedDepartmentIds().contains(departmentId)) {
|
||||||
return true;
|
return true;
|
||||||
@@ -148,7 +144,7 @@ public class AuthServiceImpl implements AuthService {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups);
|
log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups);
|
||||||
return groups;
|
return groups;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.component;
|
package com.tencent.supersonic.chat.api.component;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic corrector checks validity of extracted semantic information and
|
* A semantic corrector checks validity of extracted semantic information and
|
||||||
@@ -9,5 +9,5 @@ import net.sf.jsqlparser.JSQLParserException;
|
|||||||
*/
|
*/
|
||||||
public interface SemanticCorrector {
|
public interface SemanticCorrector {
|
||||||
|
|
||||||
void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException;
|
void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,13 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
|||||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ public interface SemanticInterpreter {
|
|||||||
|
|
||||||
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
|
QueryResultWithSchemaResp queryByS2SQL(QueryS2SQLReq queryS2SQLReq, User user);
|
||||||
|
|
||||||
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||||
|
|
||||||
@@ -47,9 +48,9 @@ public interface SemanticInterpreter {
|
|||||||
|
|
||||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
||||||
|
|
||||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
|
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||||
|
|
||||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user);
|
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
||||||
|
|
||||||
List<DomainResp> getDomainList(User user);
|
List<DomainResp> getDomainList(User user);
|
||||||
|
|
||||||
@@ -57,4 +58,6 @@ public interface SemanticInterpreter {
|
|||||||
|
|
||||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||||
|
|
||||||
|
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.api.component;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
|
||||||
import org.apache.calcite.sql.parser.SqlParseException;
|
import org.apache.calcite.sql.parser.SqlParseException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -15,7 +14,9 @@ public interface SemanticQuery {
|
|||||||
|
|
||||||
QueryResult execute(User user) throws SqlParseException;
|
QueryResult execute(User user) throws SqlParseException;
|
||||||
|
|
||||||
ExplainResp explain(User user);
|
void initS2Sql(User user);
|
||||||
|
|
||||||
|
String explain(User user);
|
||||||
|
|
||||||
SemanticParseInfo getParseInfo();
|
SemanticParseInfo getParseInfo();
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@@ -13,7 +18,9 @@ public class ModelSchema {
|
|||||||
private Set<SchemaElement> metrics = new HashSet<>();
|
private Set<SchemaElement> metrics = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||||
|
private Set<SchemaElement> tags = new HashSet<>();
|
||||||
private SchemaElement entity = new SchemaElement();
|
private SchemaElement entity = new SchemaElement();
|
||||||
|
private List<ModelRela> modelRelas = new ArrayList<>();
|
||||||
|
|
||||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
Optional<SchemaElement> element = Optional.empty();
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
@@ -34,6 +41,9 @@ public class ModelSchema {
|
|||||||
case VALUE:
|
case VALUE:
|
||||||
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
|
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
break;
|
break;
|
||||||
|
case TAG:
|
||||||
|
element = tags.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,4 +54,45 @@ public class ModelSchema {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElement(SchemaElementType elementType, String name) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = Optional.ofNullable(entity);
|
||||||
|
break;
|
||||||
|
case MODEL:
|
||||||
|
element = Optional.of(model);
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = dimensions.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = dimensionValues.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<Long> getModelClusterSet() {
|
||||||
|
if (CollectionUtils.isEmpty(modelRelas)) {
|
||||||
|
return Sets.newHashSet();
|
||||||
|
}
|
||||||
|
Set<Long> modelClusterSet = new HashSet<>();
|
||||||
|
modelRelas.forEach(modelRela -> {
|
||||||
|
modelClusterSet.add(modelRela.getToModelId());
|
||||||
|
modelClusterSet.add(modelRela.getFromModelId());
|
||||||
|
});
|
||||||
|
return modelClusterSet;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ public class QueryContext {
|
|||||||
private QueryReq request;
|
private QueryReq request;
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
|
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||||
|
|
||||||
public QueryContext(QueryReq request) {
|
public QueryContext(QueryReq request) {
|
||||||
this.request = request;
|
this.request = request;
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class RelatedSchemaElement {
|
||||||
|
|
||||||
|
private Long dimensionId;
|
||||||
|
|
||||||
|
private boolean isNecessary;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,14 +1,15 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Getter
|
@Getter
|
||||||
@Builder
|
@Builder
|
||||||
@@ -22,13 +23,14 @@ public class SchemaElement implements Serializable {
|
|||||||
private String bizName;
|
private String bizName;
|
||||||
private Long useCnt;
|
private Long useCnt;
|
||||||
private SchemaElementType type;
|
private SchemaElementType type;
|
||||||
|
|
||||||
private List<String> alias;
|
private List<String> alias;
|
||||||
|
|
||||||
private List<SchemaValueMap> schemaValueMaps;
|
private List<SchemaValueMap> schemaValueMaps;
|
||||||
|
private List<RelatedSchemaElement> relatedSchemaElements;
|
||||||
|
|
||||||
private String defaultAgg;
|
private String defaultAgg;
|
||||||
|
|
||||||
|
private double order;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object o) {
|
public boolean equals(Object o) {
|
||||||
if (this == o) {
|
if (this == o) {
|
||||||
@@ -40,13 +42,13 @@ public class SchemaElement implements Serializable {
|
|||||||
SchemaElement schemaElement = (SchemaElement) o;
|
SchemaElement schemaElement = (SchemaElement) o;
|
||||||
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
|
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
|
||||||
schemaElement.id) && Objects.equal(name, schemaElement.name)
|
schemaElement.id) && Objects.equal(name, schemaElement.name)
|
||||||
&& Objects.equal(bizName, schemaElement.bizName) && Objects.equal(
|
&& Objects.equal(bizName, schemaElement.bizName)
|
||||||
useCnt, schemaElement.useCnt) && Objects.equal(type, schemaElement.type);
|
&& Objects.equal(type, schemaElement.type);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hashCode(model, id, name, bizName, useCnt, type);
|
return Objects.hashCode(model, id, name, bizName, type);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ public enum SchemaElementType {
|
|||||||
DIMENSION,
|
DIMENSION,
|
||||||
VALUE,
|
VALUE,
|
||||||
ENTITY,
|
ENTITY,
|
||||||
|
TAG,
|
||||||
ID,
|
ID,
|
||||||
DATE
|
DATE
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import lombok.Data;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SchemaModelClusterMapInfo {
|
||||||
|
|
||||||
|
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||||
|
|
||||||
|
public Set<String> getMatchedModelClusters() {
|
||||||
|
return modelElementMatches.keySet();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
||||||
|
for (String key : modelElementMatches.keySet()) {
|
||||||
|
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||||
|
return modelElementMatches.get(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Lists.newArrayList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
||||||
|
return modelElementMatches.get(modelCluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
||||||
|
return modelElementMatches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
||||||
|
if (CollectionUtils.isEmpty(modelIds)) {
|
||||||
|
return modelElementMatches;
|
||||||
|
}
|
||||||
|
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
||||||
|
for (String key : modelElementMatches.keySet()) {
|
||||||
|
for (Long modelId : modelIds) {
|
||||||
|
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||||
|
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelElementMatchesFiltered;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
||||||
|
this.modelElementMatches = modelElementMatches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
||||||
|
modelElementMatches.put(modelCluster, elementMatches);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,8 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
|
import com.tencent.supersonic.common.pojo.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -15,15 +20,13 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.TreeSet;
|
import java.util.TreeSet;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class SemanticParseInfo {
|
public class SemanticParseInfo {
|
||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode;
|
private String queryMode;
|
||||||
private SchemaElement model;
|
private ModelCluster model = new ModelCluster();
|
||||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||||
private SchemaElement entity;
|
private SchemaElement entity;
|
||||||
@@ -34,25 +37,38 @@ public class SemanticParseInfo {
|
|||||||
private Set<Order> orders = new LinkedHashSet();
|
private Set<Order> orders = new LinkedHashSet();
|
||||||
private DateConf dateInfo;
|
private DateConf dateInfo;
|
||||||
private Long limit;
|
private Long limit;
|
||||||
private Boolean nativeQuery = false;
|
|
||||||
private double score;
|
private double score;
|
||||||
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
||||||
private Map<String, Object> properties = new HashMap<>();
|
private Map<String, Object> properties = new HashMap<>();
|
||||||
private EntityInfo entityInfo;
|
private EntityInfo entityInfo;
|
||||||
private SqlInfo sqlInfo = new SqlInfo();
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
|
private QueryType queryType = QueryType.OTHER;
|
||||||
|
|
||||||
public Long getModelId() {
|
public String getModelClusterKey() {
|
||||||
return model != null ? model.getId() : 0L;
|
if (model == null) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return model.getKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
public String getModelName() {
|
public String getModelName() {
|
||||||
return model != null ? model.getName() : "null";
|
if (model == null) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
return model.getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int compare(SchemaElement o1, SchemaElement o2) {
|
public int compare(SchemaElement o1, SchemaElement o2) {
|
||||||
|
if (o1.getOrder() != o2.getOrder()) {
|
||||||
|
if (o1.getOrder() < o2.getOrder()) {
|
||||||
|
return -1;
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
int len1 = o1.getName().length();
|
int len1 = o1.getName().length();
|
||||||
int len2 = o2.getName().length();
|
int len2 = o2.getName().length();
|
||||||
if (len1 != len2) {
|
if (len1 != len2) {
|
||||||
@@ -70,4 +86,26 @@ public class SemanticParseInfo {
|
|||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Map<Long, Integer> getModelElementCountMap() {
|
||||||
|
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||||
|
elementMatches.forEach(element -> {
|
||||||
|
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||||
|
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||||
|
});
|
||||||
|
return elementCountMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long getModelId() {
|
||||||
|
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
||||||
|
Long modelId = -1L;
|
||||||
|
int maxCnt = 0;
|
||||||
|
for (Long model : elementCountMap.keySet()) {
|
||||||
|
if (elementCountMap.get(model) > maxCnt) {
|
||||||
|
maxCnt = elementCountMap.get(model);
|
||||||
|
modelId = model;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelId;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,14 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo;
|
package com.tencent.supersonic.chat.api.pojo;
|
||||||
|
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class SemanticSchema implements Serializable {
|
public class SemanticSchema implements Serializable {
|
||||||
@@ -18,6 +23,64 @@ public class SemanticSchema implements Serializable {
|
|||||||
modelSchemaList.add(schema);
|
modelSchemaList.add(schema);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = getElementsById(elementID, getEntities());
|
||||||
|
break;
|
||||||
|
case MODEL:
|
||||||
|
element = getElementsById(elementID, getModels());
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = getElementsById(elementID, getMetrics());
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = getElementsById(elementID, getDimensions());
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = getElementsById(elementID, getDimensionValues());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
||||||
|
Optional<SchemaElement> element = Optional.empty();
|
||||||
|
|
||||||
|
switch (elementType) {
|
||||||
|
case ENTITY:
|
||||||
|
element = getElementsByName(name, getEntities());
|
||||||
|
break;
|
||||||
|
case MODEL:
|
||||||
|
element = getElementsByName(name, getModels());
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
element = getElementsByName(name, getMetrics());
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
element = getElementsByName(name, getDimensions());
|
||||||
|
break;
|
||||||
|
case VALUE:
|
||||||
|
element = getElementsByName(name, getDimensionValues());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.isPresent()) {
|
||||||
|
return element.get();
|
||||||
|
} else {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public Map<Long, String> getModelIdToName() {
|
public Map<Long, String> getModelIdToName() {
|
||||||
return modelSchemaList.stream()
|
return modelSchemaList.stream()
|
||||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
||||||
@@ -35,9 +98,28 @@ public class SemanticSchema implements Serializable {
|
|||||||
return dimensions;
|
return dimensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getDimensions(Long modelId) {
|
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
||||||
List<SchemaElement> dimensions = getDimensions();
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
return getElementsByModelId(modelId, dimensions);
|
return getElementsByModelId(modelIds, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getDimensions(Long id) {
|
||||||
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
|
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||||
|
return dimension.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTags() {
|
||||||
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
|
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||||
|
return tags;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||||
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
|
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
|
.forEach(d -> tags.addAll(d.getTags()));
|
||||||
|
return tags;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics() {
|
public List<SchemaElement> getMetrics() {
|
||||||
@@ -46,26 +128,54 @@ public class SemanticSchema implements Serializable {
|
|||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics(Long modelId) {
|
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
||||||
List<SchemaElement> metrics = getMetrics();
|
List<SchemaElement> metrics = getMetrics();
|
||||||
return getElementsByModelId(modelId, metrics);
|
return getElementsByModelId(modelIds, metrics);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
|
public List<SchemaElement> getEntities() {
|
||||||
|
List<SchemaElement> entities = new ArrayList<>();
|
||||||
|
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||||
|
return entities;
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||||
return elements.stream()
|
return elements.stream()
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> id.equals(schemaElement.getId()))
|
||||||
|
.findFirst();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> name.equals(schemaElement.getName()))
|
||||||
|
.findFirst();
|
||||||
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getModels() {
|
public List<SchemaElement> getModels() {
|
||||||
List<SchemaElement> models = new ArrayList<>();
|
List<SchemaElement> models = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||||
return models;
|
return models;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getEntities() {
|
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
||||||
List<SchemaElement> entities = new ArrayList<>();
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
allElements.addAll(getDimensions(modelIds));
|
||||||
return entities;
|
allElements.addAll(getMetrics(modelIds));
|
||||||
|
return allElements.stream()
|
||||||
|
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Long, ModelSchema> getModelSchemaMap() {
|
||||||
|
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
||||||
|
return new HashMap<>();
|
||||||
|
}
|
||||||
|
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||||
|
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ public class ChatConfigBaseReq {
|
|||||||
*/
|
*/
|
||||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* the llm examples about the model
|
||||||
|
*/
|
||||||
|
private String llmExamples;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* available status
|
* available status
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,16 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
import javax.validation.constraints.NotNull;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DimensionValueReq {
|
public class DimensionValueReq {
|
||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
private Long elementID;
|
private Long elementID;
|
||||||
|
|
||||||
|
@NotNull
|
||||||
private Long modelId;
|
private Long modelId;
|
||||||
|
|
||||||
private String bizName;
|
private String bizName;
|
||||||
|
|
||||||
private Object value;
|
@NotNull
|
||||||
|
private String value;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,16 +3,18 @@ package com.tencent.supersonic.chat.api.pojo.request;
|
|||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Builder
|
||||||
@Data
|
@Data
|
||||||
public class ExecuteQueryReq {
|
public class ExecuteQueryReq {
|
||||||
private User user;
|
private User user;
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Long queryId = 7L;
|
private Long queryId;
|
||||||
private Integer parseId = 2;
|
private Integer parseId;
|
||||||
private SemanticParseInfo parseInfo;
|
private SemanticParseInfo parseInfo;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,4 +13,8 @@ public class PageQueryInfoReq {
|
|||||||
private String userName;
|
private String userName;
|
||||||
|
|
||||||
private List<Long> ids;
|
private List<Long> ids;
|
||||||
|
|
||||||
|
public Integer getLimitStart() {
|
||||||
|
return this.pageSize * (this.current - 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,6 @@ public class QueryDataReq {
|
|||||||
private Set<QueryFilter> dimensionFilters = new HashSet<>();
|
private Set<QueryFilter> dimensionFilters = new HashSet<>();
|
||||||
private Set<QueryFilter> metricFilters = new HashSet<>();
|
private Set<QueryFilter> metricFilters = new HashSet<>();
|
||||||
private DateConf dateInfo;
|
private DateConf dateInfo;
|
||||||
private Long queryId = 7L;
|
private Long queryId;
|
||||||
private Integer parseId = 2;
|
private Integer parseId;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
|||||||
public class QueryReq {
|
public class QueryReq {
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private Long modelId = 0L;
|
private Long modelId;
|
||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class RecommendReq {
|
||||||
|
|
||||||
|
private Long modelId;
|
||||||
|
|
||||||
|
private Long metricId;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -18,7 +18,7 @@ public class SolvedQueryReq {
|
|||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
|
|
||||||
private Long modelId;
|
private String modelId;
|
||||||
|
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ public class ChatConfigResp {
|
|||||||
|
|
||||||
private List<RecommendedQuestionReq> recommendedQuestions;
|
private List<RecommendedQuestionReq> recommendedQuestions;
|
||||||
|
|
||||||
|
private String llmExamples;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* available status
|
* available status
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ModelInfo extends DataInfo implements Serializable {
|
public class ModelInfo extends DataInfo implements Serializable {
|
||||||
|
|
||||||
private List<String> words;
|
private List<String> words;
|
||||||
private String primaryEntityName;
|
private String primaryKey;
|
||||||
private String primaryEntityBizName;
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,31 +1,24 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Getter
|
|
||||||
@Builder
|
|
||||||
@NoArgsConstructor
|
|
||||||
@AllArgsConstructor
|
|
||||||
public class ParseResp {
|
public class ParseResp {
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private ParseState state;
|
private ParseState state;
|
||||||
private List<SemanticParseInfo> selectedParses;
|
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||||
private List<SemanticParseInfo> candidateParses;
|
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||||
private List<SolvedQueryRecallResp> similarSolvedQuery;
|
private ParseTimeCostDO parseTimeCost = new ParseTimeCostDO();
|
||||||
|
|
||||||
public enum ParseState {
|
public enum ParseState {
|
||||||
COMPLETED,
|
COMPLETED,
|
||||||
PENDING,
|
PENDING,
|
||||||
FAILED
|
FAILED
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class ParseTimeCostDO {
|
||||||
|
|
||||||
|
private long parseStartTime;
|
||||||
|
private long parseTime;
|
||||||
|
private long sqlTime;
|
||||||
|
|
||||||
|
public ParseTimeCostDO() {
|
||||||
|
this.parseStartTime = System.currentTimeMillis();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class QueryRecallResp {
|
||||||
|
private List<SolvedQueryRecallResp> solvedQueryRecallRespList;
|
||||||
|
private Long queryTimeCost;
|
||||||
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.response;
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
|
import java.util.List;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -13,4 +15,5 @@ public class QueryResp {
|
|||||||
private String feedback;
|
private String feedback;
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private QueryResult queryResult;
|
private QueryResult queryResult;
|
||||||
|
private List<SemanticParseInfo> parseInfos;
|
||||||
}
|
}
|
||||||
@@ -21,4 +21,5 @@ public class QueryResult {
|
|||||||
private SemanticParseInfo chatContext;
|
private SemanticParseInfo chatContext;
|
||||||
private Object response;
|
private Object response;
|
||||||
private List<Map<String, Object>> queryResults;
|
private List<Map<String, Object>> queryResults;
|
||||||
|
private Long queryTimeCost;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class SqlInfo {
|
public class SqlInfo {
|
||||||
|
|
||||||
private String llmParseSql;
|
private String s2SQL;
|
||||||
private String logicSql;
|
private String correctS2SQL;
|
||||||
private String querySql;
|
private String querySQL;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,16 +59,7 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis.spring.boot</groupId>
|
|
||||||
<artifactId>mybatis-spring-boot-starter</artifactId>
|
|
||||||
<version>${mybatis-spring.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.alibaba</groupId>
|
<groupId>com.alibaba</groupId>
|
||||||
<artifactId>druid</artifactId>
|
<artifactId>druid</artifactId>
|
||||||
@@ -78,24 +69,6 @@
|
|||||||
<groupId>mysql</groupId>
|
<groupId>mysql</groupId>
|
||||||
<artifactId>mysql-connector-java</artifactId>
|
<artifactId>mysql-connector-java</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis</groupId>
|
|
||||||
<artifactId>mybatis-spring</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.mybatis.spring.boot</groupId>
|
|
||||||
<artifactId>mybatis-spring-boot-starter-test</artifactId>
|
|
||||||
<version>${mybatis.test.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.github.pagehelper</groupId>
|
|
||||||
<artifactId>pagehelper-spring-boot-starter</artifactId>
|
|
||||||
<version>${pagehelper.spring.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.h2database</groupId>
|
<groupId>com.h2database</groupId>
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.agent.tool;
|
|||||||
|
|
||||||
public enum AgentToolType {
|
public enum AgentToolType {
|
||||||
RULE,
|
RULE,
|
||||||
DSL,
|
LLM_S2SQL,
|
||||||
PLUGIN,
|
PLUGIN,
|
||||||
INTERPRET
|
INTERPRET
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.chat.agent.tool;
|
||||||
|
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class CommonAgentTool extends AgentTool {
|
||||||
|
|
||||||
|
protected List<Long> modelIds;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -5,9 +5,7 @@ import lombok.Data;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DslTool extends AgentTool {
|
public class LLMParserTool extends CommonAgentTool {
|
||||||
|
|
||||||
private List<Long> modelIds;
|
|
||||||
|
|
||||||
private List<String> exampleQuestions;
|
private List<String> exampleQuestions;
|
||||||
|
|
||||||
@@ -7,12 +7,13 @@ import org.apache.commons.collections.CollectionUtils;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class RuleQueryTool extends AgentTool {
|
public class RuleQueryTool extends CommonAgentTool {
|
||||||
|
|
||||||
private List<Long> modelIds;
|
|
||||||
|
|
||||||
private List<String> queryModes;
|
private List<String> queryModes;
|
||||||
|
|
||||||
|
private List<String> queryTypes;
|
||||||
|
|
||||||
public boolean isContainsAllModel() {
|
public boolean isContainsAllModel() {
|
||||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,43 +1,161 @@
|
|||||||
package com.tencent.supersonic.chat.config;
|
package com.tencent.supersonic.chat.config;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.service.SysParameterService;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.context.annotation.PropertySource;
|
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@Data
|
@Data
|
||||||
@PropertySource("classpath:optimization.properties")
|
@Slf4j
|
||||||
//@ComponentScan(basePackages = "com.tencent.supersonic.chat")
|
|
||||||
public class OptimizationConfig {
|
public class OptimizationConfig {
|
||||||
|
|
||||||
@Value("${one.detection.size}")
|
@Value("${one.detection.size:8}")
|
||||||
private Integer oneDetectionSize;
|
private Integer oneDetectionSize;
|
||||||
@Value("${one.detection.max.size}")
|
|
||||||
|
@Value("${one.detection.max.size:20}")
|
||||||
private Integer oneDetectionMaxSize;
|
private Integer oneDetectionMaxSize;
|
||||||
|
|
||||||
@Value("${metric.dimension.min.threshold}")
|
@Value("${metric.dimension.min.threshold:0.3}")
|
||||||
private Double metricDimensionMinThresholdConfig;
|
private Double metricDimensionMinThresholdConfig;
|
||||||
|
|
||||||
@Value("${metric.dimension.threshold}")
|
@Value("${metric.dimension.threshold:0.3}")
|
||||||
private Double metricDimensionThresholdConfig;
|
private Double metricDimensionThresholdConfig;
|
||||||
|
|
||||||
@Value("${dimension.value.threshold}")
|
@Value("${dimension.value.threshold:0.5}")
|
||||||
private Double dimensionValueThresholdConfig;
|
private Double dimensionValueThresholdConfig;
|
||||||
|
|
||||||
@Value("${function.bonus.threshold}")
|
@Value("${long.text.threshold:0.8}")
|
||||||
private Double functionBonusThreshold;
|
|
||||||
|
|
||||||
@Value("${long.text.threshold}")
|
|
||||||
private Double longTextThreshold;
|
private Double longTextThreshold;
|
||||||
|
|
||||||
@Value("${short.text.threshold}")
|
@Value("${short.text.threshold:0.5}")
|
||||||
private Double shortTextThreshold;
|
private Double shortTextThreshold;
|
||||||
|
|
||||||
@Value("${query.text.length.threshold}")
|
@Value("${query.text.length.threshold:10}")
|
||||||
private Integer queryTextLengthThreshold;
|
private Integer queryTextLengthThreshold;
|
||||||
|
@Value("${embedding.mapper.word.min:4}")
|
||||||
|
private int embeddingMapperWordMin;
|
||||||
|
|
||||||
@Value("${candidate.threshold}")
|
@Value("${embedding.mapper.word.max:5}")
|
||||||
private Double candidateThreshold;
|
private int embeddingMapperWordMax;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.batch:50}")
|
||||||
|
private int embeddingMapperBatch;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.number:5}")
|
||||||
|
private int embeddingMapperNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
|
@Value("${embedding.mapper.distance.threshold:0.58}")
|
||||||
|
private Double embeddingMapperDistanceThreshold;
|
||||||
|
|
||||||
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
|
private boolean useLinkingValueSwitch;
|
||||||
|
|
||||||
|
@Value("${s2SQL.use.switch:true}")
|
||||||
|
private boolean useS2SqlSwitch;
|
||||||
|
|
||||||
|
@Value("${text2sql.example.num:10}")
|
||||||
|
private int text2sqlExampleNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.fewShots.num:10}")
|
||||||
|
private int text2sqlFewShotsNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.self.consistency.num:5}")
|
||||||
|
private int text2sqlSelfConsistencyNum;
|
||||||
|
|
||||||
|
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
|
||||||
|
private String text2sqlCollectionName;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SysParameterService sysParameterService;
|
||||||
|
|
||||||
|
public Integer getOneDetectionSize() {
|
||||||
|
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getOneDetectionMaxSize() {
|
||||||
|
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMetricDimensionMinThresholdConfig() {
|
||||||
|
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getMetricDimensionThresholdConfig() {
|
||||||
|
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getDimensionValueThresholdConfig() {
|
||||||
|
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getLongTextThreshold() {
|
||||||
|
return convertValue("long.text.threshold", Double.class, longTextThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getShortTextThreshold() {
|
||||||
|
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getQueryTextLengthThreshold() {
|
||||||
|
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isUseS2SqlSwitch() {
|
||||||
|
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperWordMin() {
|
||||||
|
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperWordMax() {
|
||||||
|
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperBatch() {
|
||||||
|
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperNumber() {
|
||||||
|
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getEmbeddingMapperRoundNumber() {
|
||||||
|
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getEmbeddingMapperDistanceThreshold() {
|
||||||
|
return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isUseLinkingValueSwitch() {
|
||||||
|
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||||
|
}
|
||||||
|
|
||||||
|
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||||
|
try {
|
||||||
|
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||||
|
if (StringUtils.isBlank(value)) {
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
if (targetType == Double.class) {
|
||||||
|
return targetType.cast(Double.parseDouble(value));
|
||||||
|
} else if (targetType == Integer.class) {
|
||||||
|
return targetType.cast(Integer.parseInt(value));
|
||||||
|
} else if (targetType == Boolean.class) {
|
||||||
|
return targetType.cast(Boolean.parseBoolean(value));
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("convertValue", e);
|
||||||
|
}
|
||||||
|
return defaultValue;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,14 +2,20 @@ package com.tencent.supersonic.chat.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -17,17 +23,30 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* basic semantic correction functionality, offering common methods and an
|
||||||
|
* abstract method called doCorrect
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||||
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void correct(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
try {
|
||||||
|
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
doCorrect(queryReq, semanticParseInfo);
|
||||||
|
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<String, String> getFieldNameMap(Long modelId) {
|
|
||||||
|
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
|
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
@@ -35,34 +54,59 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||||
|
|
||||||
|
// support fieldName and field alias
|
||||||
Map<String, String> result = dbAllFields.stream()
|
Map<String, String> result = dbAllFields.stream()
|
||||||
.filter(entry -> entry.getModel().equals(modelId))
|
.filter(entry -> modelIds.contains(entry.getModel()))
|
||||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getName(), (k1, k2) -> k1));
|
.flatMap(schemaElement -> {
|
||||||
result.put(DateUtils.DATE_FIELD, DateUtils.DATE_FIELD);
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
|
||||||
|
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
|
||||||
|
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
|
||||||
|
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
|
||||||
|
|
||||||
|
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
|
||||||
|
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
|
||||||
|
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addFieldsToSelect(SemanticCorrectInfo semanticCorrectInfo, String sql) {
|
protected void addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
|
||||||
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
|
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(correctS2SQL));
|
||||||
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(sql));
|
Set<String> needAddFields = new HashSet<>(SqlParserSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
|
needAddFields.addAll(SqlParserSelectHelper.getOrderByFields(correctS2SQL));
|
||||||
|
|
||||||
|
// If there is no aggregate function in the S2SQL statement and
|
||||||
|
// there is a data field in 'WHERE' statement, add the field to the 'SELECT' statement.
|
||||||
|
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||||
|
List<String> timeChNameList = TimeDimensionEnum.getChNameList();
|
||||||
|
Set<String> timeFields = whereFields.stream().filter(field -> timeChNameList.contains(field))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
needAddFields.addAll(timeFields);
|
||||||
|
}
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
needAddFields.removeAll(selectFields);
|
needAddFields.removeAll(selectFields);
|
||||||
needAddFields.remove(DateUtils.DATE_FIELD);
|
String replaceFields = SqlParserAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||||
String replaceFields = SqlParserAddHelper.addFieldsToSelect(sql, new ArrayList<>(needAddFields));
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceFields);
|
||||||
semanticCorrectInfo.setSql(replaceFields);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
String sql = semanticCorrectInfo.getSql();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
|
||||||
List<SchemaElement> metrics = getMetricElements(modelId);
|
List<SchemaElement> metrics = getMetricElements(modelIds);
|
||||||
|
|
||||||
Map<String, String> metricToAggregate = metrics.stream()
|
Map<String, String> metricToAggregate = metrics.stream()
|
||||||
.map(schemaElement -> {
|
.map(schemaElement -> {
|
||||||
@@ -75,18 +119,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
String aggregateSql = SqlParserAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||||
String aggregateSql = SqlParserAddHelper.addAggregateToField(sql, metricToAggregate);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||||
semanticCorrectInfo.setSql(aggregateSql);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMetricElements(Long modelId) {
|
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
return semanticSchema.getMetrics(modelId);
|
return semanticSchema.getMetrics(modelIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getDimensionElements(Long modelId) {
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
return semanticSchema.getDimensions(modelId);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class FromCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
String modelName = semanticParseInfo.getModel().getName();
|
||||||
|
SqlParserReplaceHelper.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,30 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
|
|
||||||
super.correct(semanticCorrectInfo);
|
|
||||||
String sql = semanticCorrectInfo.getSql();
|
|
||||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(sql)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
|
|
||||||
if (Objects.nonNull(havingExpression)) {
|
|
||||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(sql, havingExpression);
|
|
||||||
semanticCorrectInfo.setSql(replaceSql);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
|
|
||||||
super.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
replaceAlias(semanticCorrectInfo);
|
|
||||||
|
|
||||||
updateFieldNameByLinkingValue(semanticCorrectInfo);
|
|
||||||
|
|
||||||
updateFieldValueByLinkingValue(semanticCorrectInfo);
|
|
||||||
|
|
||||||
correctFieldName(semanticCorrectInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String replaceAlias = SqlParserReplaceHelper.replaceAlias(semanticCorrectInfo.getSql());
|
|
||||||
semanticCorrectInfo.setSql(replaceAlias);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void correctFieldName(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
|
|
||||||
Map<String, String> fieldNameMap = getFieldNameMap(semanticCorrectInfo.getParseInfo().getModelId());
|
|
||||||
|
|
||||||
String sql = SqlParserReplaceHelper.replaceFields(semanticCorrectInfo.getSql(), fieldNameMap);
|
|
||||||
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
|
|
||||||
if (CollectionUtils.isEmpty(linking)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
|
||||||
Collectors.groupingBy(ElementValue::getFieldValue,
|
|
||||||
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
|
||||||
|
|
||||||
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(),
|
|
||||||
fieldValueToFieldNames);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<ElementValue> getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
|
||||||
if (Objects.isNull(context)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class);
|
|
||||||
if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
LLMReq llmReq = dslParseResult.getLlmReq();
|
|
||||||
return llmReq.getLinking();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
List<ElementValue> linking = getLinkingValues(semanticCorrectInfo);
|
|
||||||
if (CollectionUtils.isEmpty(linking)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
|
||||||
Collectors.groupingBy(ElementValue::getFieldName,
|
|
||||||
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
|
|
||||||
oldValue -> oldValue,
|
|
||||||
newValue -> newValue,
|
|
||||||
(existingValue, newValue) -> newValue)
|
|
||||||
)));
|
|
||||||
|
|
||||||
String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,46 +1,71 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "group by" section in S2SQL.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
super.correct(semanticCorrectInfo);
|
addGroupByFields(semanticParseInfo);
|
||||||
|
|
||||||
addGroupByFields(semanticCorrectInfo);
|
|
||||||
|
|
||||||
addAggregate(semanticCorrectInfo);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addGroupByFields(SemanticCorrectInfo semanticCorrectInfo) {
|
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
String sql = semanticCorrectInfo.getSql();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
|
//add alias field name
|
||||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
||||||
dimensions.add(DateUtils.DATE_FIELD);
|
.flatMap(
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
schemaElement -> {
|
||||||
|
Set<String> elements = new HashSet<>();
|
||||||
|
elements.add(schemaElement.getName());
|
||||||
|
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||||
|
elements.addAll(schemaElement.getAlias());
|
||||||
|
}
|
||||||
|
return elements.stream();
|
||||||
|
}
|
||||||
|
).collect(Collectors.toSet());
|
||||||
|
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
|
// If the select clause contains only the date field, do not add a group by.
|
||||||
|
if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (SqlParserSelectHelper.hasGroupBy(correctS2SQL)) {
|
||||||
|
log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
Set<String> groupByFields = selectFields.stream()
|
Set<String> groupByFields = selectFields.stream()
|
||||||
.filter(field -> dimensions.contains(field))
|
.filter(field -> dimensions.contains(field))
|
||||||
.filter(field -> {
|
.filter(field -> {
|
||||||
@@ -50,14 +75,17 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
return true;
|
return true;
|
||||||
})
|
})
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
semanticCorrectInfo.setSql(SqlParserAddHelper.addGroupBy(sql, groupByFields));
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlParserAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||||
|
|
||||||
|
addAggregate(semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addAggregate(SemanticCorrectInfo semanticCorrectInfo) {
|
private void addAggregate(SemanticParseInfo semanticParseInfo) {
|
||||||
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
|
List<String> sqlGroupByFields = SqlParserSelectHelper.getGroupByFields(
|
||||||
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
addAggregateToMetric(semanticCorrectInfo);
|
addAggregateToMetric(semanticParseInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,37 +1,66 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class HavingCorrector extends BaseSemanticCorrector {
|
public class HavingCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
super.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
addHaving(semanticParseInfo);
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getModel();
|
|
||||||
|
//add having expression field to select
|
||||||
|
addHavingToSelect(semanticParseInfo);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
||||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(metrics)) {
|
if (CollectionUtils.isEmpty(metrics)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String havingSql = SqlParserAddHelper.addHaving(semanticCorrectInfo.getSql(), metrics);
|
String havingSql = SqlParserAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||||
semanticCorrectInfo.setSql(havingSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
List<Expression> havingExpressionList = SqlParserSelectHelper.getHavingExpression(correctS2SQL);
|
||||||
|
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||||
|
String replaceSql = SqlParserAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.parser.llm.s2sql.ParseResult;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform schema corrections on the Schema information in S2SQL.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
|
correctAggFunction(semanticParseInfo);
|
||||||
|
|
||||||
|
replaceAlias(semanticParseInfo);
|
||||||
|
|
||||||
|
updateFieldNameByLinkingValue(semanticParseInfo);
|
||||||
|
|
||||||
|
updateFieldValueByLinkingValue(semanticParseInfo);
|
||||||
|
|
||||||
|
correctFieldName(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String replaceAlias = SqlParserReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||||
|
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||||
|
if (CollectionUtils.isEmpty(linking)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||||
|
Collectors.groupingBy(ElementValue::getFieldValue,
|
||||||
|
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
|
||||||
|
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
|
||||||
|
String sql = SqlParserReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
|
||||||
|
Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
|
||||||
|
if (Objects.isNull(context)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
|
||||||
|
if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return parseResult.getLinkingValues();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||||
|
List<ElementValue> linking = getLinkingValues(semanticParseInfo);
|
||||||
|
if (CollectionUtils.isEmpty(linking)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
||||||
|
Collectors.groupingBy(ElementValue::getFieldName,
|
||||||
|
Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap(
|
||||||
|
oldValue -> oldValue,
|
||||||
|
newValue -> newValue,
|
||||||
|
(existingValue, newValue) -> newValue)
|
||||||
|
)));
|
||||||
|
|
||||||
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
String sql = SqlParserReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,26 +1,29 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Select" section in S2SQL.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SelectCorrector extends BaseSemanticCorrector {
|
public class SelectCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
super.correct(semanticCorrectInfo);
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
String sql = semanticCorrectInfo.getSql();
|
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(sql);
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sql);
|
|
||||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||||
&& !CollectionUtils.isEmpty(selectFields)
|
&& !CollectionUtils.isEmpty(selectFields)
|
||||||
&& aggregateFields.size() == selectFields.size()) {
|
&& aggregateFields.size() == selectFields.size()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
addFieldsToSelect(semanticCorrectInfo, sql);
|
addFieldsToSelect(semanticParseInfo, correctS2SQL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,23 +2,19 @@ package com.tencent.supersonic.chat.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.parser.llm.s2sql.S2SQLDateHelper;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -27,56 +23,67 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class WhereCorrector extends BaseSemanticCorrector {
|
public class WhereCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
super.correct(semanticCorrectInfo);
|
addDateIfNotExist(semanticParseInfo);
|
||||||
|
|
||||||
addDateIfNotExist(semanticCorrectInfo);
|
parserDateDiffFunction(semanticParseInfo);
|
||||||
|
|
||||||
parserDateDiffFunction(semanticCorrectInfo);
|
addQueryFilter(queryReq, semanticParseInfo);
|
||||||
|
|
||||||
addQueryFilter(semanticCorrectInfo);
|
updateFieldValueByTechName(semanticParseInfo);
|
||||||
|
|
||||||
updateFieldValueByTechName(semanticCorrectInfo);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) {
|
private void addQueryFilter(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
|
String queryFilter = getQueryFilter(queryReq.getQueryFilters());
|
||||||
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||||
log.info("add queryFilter to preSql :{}", queryFilter);
|
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
|
||||||
Expression expression = null;
|
Expression expression = null;
|
||||||
try {
|
try {
|
||||||
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||||
} catch (JSQLParserException e) {
|
} catch (JSQLParserException e) {
|
||||||
log.error("parseCondExpression", e);
|
log.error("parseCondExpression", e);
|
||||||
}
|
}
|
||||||
String sql = SqlParserAddHelper.addWhere(preSql, expression);
|
correctS2SQL = SqlParserAddHelper.addWhere(correctS2SQL, expression);
|
||||||
semanticCorrectInfo.setSql(sql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
|
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
String sql = semanticCorrectInfo.getSql();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
sql = SqlParserReplaceHelper.replaceFunction(sql);
|
correctS2SQL = SqlParserReplaceHelper.replaceFunction(correctS2SQL);
|
||||||
semanticCorrectInfo.setSql(sql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
|
private void addDateIfNotExist(SemanticParseInfo semanticParseInfo) {
|
||||||
String sql = semanticCorrectInfo.getSql();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DateUtils.DATE_FIELD)) {
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
String currentDate = S2SQLDateHelper.getReferenceDate(semanticParseInfo.getModelId());
|
||||||
sql = SqlParserAddHelper.addParenthesisToWhere(sql);
|
if (StringUtils.isNotBlank(currentDate)) {
|
||||||
sql = SqlParserAddHelper.addWhere(sql, DateUtils.DATE_FIELD, currentDate);
|
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||||
|
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||||
|
correctS2SQL, TimeDimensionEnum.DAY.getChName(), currentDate);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
semanticCorrectInfo.setSql(sql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getQueryFilter(QueryFilters queryFilters) {
|
private String getQueryFilter(QueryFilters queryFilters) {
|
||||||
@@ -93,21 +100,19 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
|
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
|
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
|
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||||
String sql = SqlParserReplaceHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName);
|
String correctS2SQL = SqlParserReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||||
semanticCorrectInfo.setSql(sql);
|
aliasAndBizNameToTechName);
|
||||||
return;
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||||
|
|||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* base Mapper
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BaseMapper implements SchemaMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void map(QueryContext queryContext) {
|
||||||
|
|
||||||
|
String simpleName = this.getClass().getSimpleName();
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||||
|
|
||||||
|
try {
|
||||||
|
doMap(queryContext);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("work error", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
|
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract void doMap(QueryContext queryContext);
|
||||||
|
|
||||||
|
|
||||||
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||||
|
if (schemaElementMatches == null) {
|
||||||
|
schemaElementMatches = modelElementMatches.get(modelId);
|
||||||
|
}
|
||||||
|
//remove duplication
|
||||||
|
AtomicBoolean needAddNew = new AtomicBoolean(true);
|
||||||
|
schemaElementMatches.removeIf(
|
||||||
|
existElementMatch -> {
|
||||||
|
SchemaElement existElement = existElementMatch.getElement();
|
||||||
|
SchemaElement newElement = newElementMatch.getElement();
|
||||||
|
if (existElement.equals(newElement)) {
|
||||||
|
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
needAddNew.set(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (needAddNew.get()) {
|
||||||
|
schemaElementMatches.add(newElementMatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID) {
|
||||||
|
SchemaElement element = new SchemaElement();
|
||||||
|
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
||||||
|
if (Objects.isNull(modelSchema)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||||
|
if (Objects.isNull(elementDb)) {
|
||||||
|
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
BeanUtils.copyProperties(elementDb, element);
|
||||||
|
element.setAlias(getAlias(elementDb));
|
||||||
|
return element;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<String> getAlias(SchemaElement element) {
|
||||||
|
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||||
|
element.getName())) {
|
||||||
|
return element.getAlias().stream()
|
||||||
|
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
return element.getAlias();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base Match Strategy
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
String text = queryContext.getRequest().getQueryText();
|
||||||
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||||
|
|
||||||
|
List<T> detects = detect(queryContext, terms, detectModelIds);
|
||||||
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
|
String text = queryContext.getRequest().getQueryText();
|
||||||
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
|
Set<String> detectSegments = new HashSet<>();
|
||||||
|
|
||||||
|
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||||
|
|
||||||
|
for (Integer index = startIndex; index <= text.length(); ) {
|
||||||
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
|
if (index <= text.length()) {
|
||||||
|
String detectSegment = text.substring(startIndex, index);
|
||||||
|
detectSegments.add(detectSegment);
|
||||||
|
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
|
}
|
||||||
|
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
||||||
|
return new ArrayList<>(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
||||||
|
Set<String> detectSegments) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Integer, Integer> getRegOffsetToLength(List<Term> terms) {
|
||||||
|
return terms.stream().sorted(Comparator.comparing(Term::length))
|
||||||
|
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||||
|
if (CollectionUtils.isEmpty(oneRoundResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (T oneRoundResult : oneRoundResults) {
|
||||||
|
if (existResults.contains(oneRoundResult)) {
|
||||||
|
boolean isDeleted = existResults.removeIf(
|
||||||
|
existResult -> {
|
||||||
|
boolean delete = needDelete(oneRoundResult, existResult);
|
||||||
|
if (delete) {
|
||||||
|
log.info("deleted existResult:{}", existResult);
|
||||||
|
}
|
||||||
|
return delete;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (isDeleted) {
|
||||||
|
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
|
||||||
|
existResults.add(oneRoundResult);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
existResults.add(oneRoundResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||||
|
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||||
|
terms = filterByModelIds(terms, detectModelIds);
|
||||||
|
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||||
|
List<T> matches = new ArrayList<>();
|
||||||
|
if (Objects.isNull(matchResult)) {
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||||
|
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||||
|
.map(entry -> entry.getValue()).findFirst();
|
||||||
|
|
||||||
|
if (first.isPresent()) {
|
||||||
|
matches = first.get();
|
||||||
|
}
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||||
|
logTerms(terms);
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
|
terms = terms.stream().filter(term -> {
|
||||||
|
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||||
|
if (Objects.nonNull(modelId)) {
|
||||||
|
return detectModelIds.contains(modelId);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
|
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||||
|
logTerms(terms);
|
||||||
|
}
|
||||||
|
return terms;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void logTerms(List<Term> terms) {
|
||||||
|
if (CollectionUtils.isEmpty(terms)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (Term term : terms) {
|
||||||
|
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
||||||
|
|
||||||
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
|
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||||
|
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/***
|
||||||
|
* A mapper that is capable of semantic understanding of text.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingMapper extends BaseMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void doMap(QueryContext queryContext) {
|
||||||
|
//1. query from embedding by queryText
|
||||||
|
|
||||||
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
|
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||||
|
|
||||||
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||||
|
|
||||||
|
HanlpHelper.transLetterOriginal(matchResults);
|
||||||
|
|
||||||
|
//2. build SchemaElementMatch by info
|
||||||
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
|
|
||||||
|
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||||
|
SchemaElement.class);
|
||||||
|
|
||||||
|
if (StringUtils.isBlank(matchResult.getMetadata().get("modelId"))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
long modelId = Long.parseLong(matchResult.getMetadata().get("modelId"));
|
||||||
|
|
||||||
|
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId);
|
||||||
|
if (schemaElement == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
|
.element(schemaElement)
|
||||||
|
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||||
|
.word(matchResult.getName())
|
||||||
|
.similarity(1 - matchResult.getDistance())
|
||||||
|
.detectWord(matchResult.getDetectWord())
|
||||||
|
.build();
|
||||||
|
//3. add to mapInfo
|
||||||
|
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
|
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* match strategy implement
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
@Autowired
|
||||||
|
private EmbeddingUtils embeddingUtils;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDistance() > oneRoundResult.getDistance();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(EmbeddingResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||||
|
Set<String> detectSegments) {
|
||||||
|
|
||||||
|
List<String> queryTextsList = detectSegments.stream()
|
||||||
|
.map(detectSegment -> detectSegment.trim())
|
||||||
|
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||||
|
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
|
||||||
|
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||||
|
optimizationConfig.getEmbeddingMapperBatch());
|
||||||
|
|
||||||
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
|
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||||
|
List<String> queryTextsSub) {
|
||||||
|
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||||
|
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||||
|
Map<String, String> filterCondition = null;
|
||||||
|
// step1. build query params
|
||||||
|
// if only one modelId, add to filterCondition
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||||
|
filterCondition = new HashMap<>();
|
||||||
|
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
|
.queryTextsList(queryTextsSub)
|
||||||
|
.filterCondition(filterCondition)
|
||||||
|
.queryEmbeddings(null)
|
||||||
|
.build();
|
||||||
|
// step2. retrieveQuery by detectSegment
|
||||||
|
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
||||||
|
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// step3. build EmbeddingResults. filter by modelId
|
||||||
|
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||||
|
.map(retrieveQueryResult -> {
|
||||||
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
|
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||||
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
|
retrievals.removeIf(retrieval -> {
|
||||||
|
String modelIdStr = retrieval.getMetadata().get("modelId");
|
||||||
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retrieveQueryResult;
|
||||||
|
})
|
||||||
|
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||||
|
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||||
|
.map(retrieval -> {
|
||||||
|
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||||
|
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||||
|
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||||
|
embeddingResult.setName(retrieval.getQuery());
|
||||||
|
return embeddingResult;
|
||||||
|
}))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
// step4. select mapResul in one round
|
||||||
|
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
|
||||||
|
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||||
|
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||||
|
.limit(roundNumber)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,27 +1,29 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EntityMapper implements SchemaMapper {
|
public class EntityMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
||||||
@@ -32,8 +34,9 @@ public class EntityMapper implements SchemaMapper {
|
|||||||
if (entity == null || entity.getId() == null) {
|
if (entity == null || entity.getId() == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
.filter(schemaElementMatch ->
|
||||||
|
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||||
@@ -51,7 +54,7 @@ public class EntityMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
||||||
List<SchemaElementMatch> schemaElementMatchList) {
|
List<SchemaElementMatch> schemaElementMatchList) {
|
||||||
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
||||||
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
|
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
|
|||||||
@@ -1,179 +1,67 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Map.Entry;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/***
|
||||||
|
* A mapper capable of fuzzy parsing of metric names and dimension names.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FuzzyNameMapper implements SchemaMapper {
|
public class FuzzyNameMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
|
|
||||||
log.debug("before db mapper,mapInfo:{}", queryContext.getMapInfo());
|
|
||||||
|
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
FuzzyNameMatchStrategy fuzzyNameMatchStrategy = ContextUtils.getBean(FuzzyNameMatchStrategy.class);
|
||||||
|
|
||||||
detectAndAddToSchema(queryContext, terms, semanticSchema.getDimensions(), SchemaElementType.DIMENSION);
|
|
||||||
|
|
||||||
detectAndAddToSchema(queryContext, terms, semanticSchema.getMetrics(), SchemaElementType.METRIC);
|
|
||||||
|
|
||||||
log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> models,
|
|
||||||
SchemaElementType schemaElementType) {
|
|
||||||
try {
|
|
||||||
|
|
||||||
Map<String, Set<SchemaElement>> modelResultSet = getResultSet(queryContext, terms, models);
|
|
||||||
|
|
||||||
addToSchemaMapInfo(modelResultSet, queryContext.getMapInfo(), schemaElementType);
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("detectAndAddToSchema error", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
|
|
||||||
List<SchemaElement> models) {
|
|
||||||
|
|
||||||
String queryText = queryContext.getRequest().getQueryText();
|
|
||||||
|
|
||||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
|
||||||
|
|
||||||
Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
|
List<FuzzyResult> matches = fuzzyNameMatchStrategy.getMatches(queryContext, terms);
|
||||||
|
|
||||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(models);
|
for (FuzzyResult match : matches) {
|
||||||
|
SchemaElement schemaElement = match.getSchemaElement();
|
||||||
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
|
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
||||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
if (regElementSet.contains(schemaElement.getId())) {
|
||||||
|
continue;
|
||||||
Map<String, Set<SchemaElement>> modelResultSet = new HashMap<>();
|
|
||||||
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
|
|
||||||
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
|
|
||||||
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
|
|
||||||
if (endIndex > queryText.length()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
String detectSegment = queryText.substring(startIndex, endIndex);
|
|
||||||
|
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
|
||||||
String name = entry.getKey();
|
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
|
||||||
if (!name.contains(detectSegment)
|
|
||||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
|
||||||
schemaElements = schemaElements.stream()
|
|
||||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
}
|
|
||||||
Set<SchemaElement> preSchemaElements = modelResultSet.putIfAbsent(detectSegment, schemaElements);
|
|
||||||
if (Objects.nonNull(preSchemaElements)) {
|
|
||||||
preSchemaElements.addAll(schemaElements);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
|
.element(schemaElement)
|
||||||
|
.word(schemaElement.getName())
|
||||||
|
.detectWord(match.getDetectWord())
|
||||||
|
.frequency(10000L)
|
||||||
|
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||||
|
.build();
|
||||||
|
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||||
|
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||||
}
|
}
|
||||||
return modelResultSet;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
|
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||||
|
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
if (CollectionUtils.isEmpty(elements)) {
|
||||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
return new HashSet<>();
|
||||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
|
||||||
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo()
|
|
||||||
.getModelElementMatches();
|
|
||||||
boolean existElement = modelElementMatches.entrySet().stream()
|
|
||||||
.anyMatch(entry -> entry.getValue().size() >= 1);
|
|
||||||
|
|
||||||
if (!existElement) {
|
|
||||||
double halfThreshold = metricDimensionThresholdConfig / 2;
|
|
||||||
|
|
||||||
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
|
||||||
: metricDimensionMinThresholdConfig;
|
|
||||||
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
|
||||||
modelElementMatches, metricDimensionThresholdConfig);
|
|
||||||
}
|
|
||||||
return metricDimensionThresholdConfig;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
|
||||||
return models.stream().collect(
|
|
||||||
Collectors.toMap(SchemaElement::getName, a -> {
|
|
||||||
Set<SchemaElement> result = new HashSet<>();
|
|
||||||
result.add(a);
|
|
||||||
return result;
|
|
||||||
}, (k1, k2) -> {
|
|
||||||
k1.addAll(k2);
|
|
||||||
return k1;
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addToSchemaMapInfo(Map<String, Set<SchemaElement>> mapResultRowSet, SchemaMapInfo schemaMap,
|
|
||||||
SchemaElementType schemaElementType) {
|
|
||||||
if (Objects.isNull(mapResultRowSet) || mapResultRowSet.size() <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
|
||||||
|
|
||||||
for (Map.Entry<String, Set<SchemaElement>> entry : mapResultRowSet.entrySet()) {
|
|
||||||
String detectWord = entry.getKey();
|
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
|
||||||
for (SchemaElement schemaElement : schemaElements) {
|
|
||||||
|
|
||||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
|
||||||
if (CollectionUtils.isEmpty(elements)) {
|
|
||||||
elements = new ArrayList<>();
|
|
||||||
schemaMap.setMatchedElements(schemaElement.getModel(), elements);
|
|
||||||
}
|
|
||||||
Set<Long> regElementSet = elements.stream()
|
|
||||||
.filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))
|
|
||||||
.map(elementMatch -> elementMatch.getElement().getId())
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
|
|
||||||
if (regElementSet.contains(schemaElement.getId())) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
|
||||||
.element(schemaElement)
|
|
||||||
.word(schemaElement.getName())
|
|
||||||
.detectWord(detectWord)
|
|
||||||
.frequency(10000L)
|
|
||||||
.similarity(mapperHelper.getSimilarity(detectWord, schemaElement.getName()))
|
|
||||||
.build();
|
|
||||||
log.info("schemaElementType:{},add to schema, elementMatch {}", schemaElementType, schemaElementMatch);
|
|
||||||
elements.add(schemaElementMatch);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return elements.stream()
|
||||||
|
.filter(elementMatch ->
|
||||||
|
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||||
|
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||||
|
.map(elementMatch -> elementMatch.getElement().getId())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,131 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.FuzzyResult;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fuzzy Name Match Strategy
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class FuzzyNameMatchStrategy extends BaseMatchStrategy<FuzzyResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
@Autowired
|
||||||
|
private SchemaService schemaService;
|
||||||
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<FuzzyResult>> match(QueryContext queryContext, List<Term> terms,
|
||||||
|
Set<Long> detectModelIds) {
|
||||||
|
this.allElements = getSchemaElements();
|
||||||
|
return super.match(queryContext, terms, detectModelIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(FuzzyResult oneRoundResult, FuzzyResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(FuzzyResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||||
|
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<FuzzyResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
||||||
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
||||||
|
|
||||||
|
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||||
|
|
||||||
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
|
||||||
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
|
String name = entry.getKey();
|
||||||
|
if (!name.contains(detectSegment)
|
||||||
|
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
|
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||||
|
schemaElements = schemaElements.stream()
|
||||||
|
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
}
|
||||||
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
|
FuzzyResult fuzzyResult = new FuzzyResult();
|
||||||
|
fuzzyResult.setDetectWord(detectSegment);
|
||||||
|
fuzzyResult.setName(schemaElement.getName());
|
||||||
|
fuzzyResult.setSchemaElement(schemaElement);
|
||||||
|
existResults.add(fuzzyResult);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getSchemaElements() {
|
||||||
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
|
allElements.addAll(schemaService.getSemanticSchema().getDimensions());
|
||||||
|
allElements.addAll(schemaService.getSemanticSchema().getMetrics());
|
||||||
|
return allElements;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private Double getThreshold(QueryContext queryContext) {
|
||||||
|
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||||
|
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||||
|
|
||||||
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||||
|
|
||||||
|
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||||
|
|
||||||
|
if (!existElement) {
|
||||||
|
double halfThreshold = metricDimensionThresholdConfig / 2;
|
||||||
|
|
||||||
|
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
|
||||||
|
: metricDimensionMinThresholdConfig;
|
||||||
|
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
|
||||||
|
modelElementMatches, metricDimensionThresholdConfig);
|
||||||
|
}
|
||||||
|
return metricDimensionThresholdConfig;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||||
|
return models.stream().collect(
|
||||||
|
Collectors.toMap(SchemaElement::getName, a -> {
|
||||||
|
Set<SchemaElement> result = new HashSet<>();
|
||||||
|
result.add(a);
|
||||||
|
return result;
|
||||||
|
}, (k1, k2) -> {
|
||||||
|
k1.addAll(k2);
|
||||||
|
return k1;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,83 +1,48 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
|
|
||||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.BeanUtils;
|
|
||||||
|
|
||||||
|
/***
|
||||||
|
* A mapper capable of prefix and suffix similarity parsing for
|
||||||
|
* domain names, dimension values, metric names, and dimension names.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class HanlpDictMapper implements SchemaMapper {
|
public class HanlpDictMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
|
|
||||||
String queryText = queryContext.getRequest().getQueryText();
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||||
|
|
||||||
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
|
HanlpDictMatchStrategy matchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
|
||||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
|
|
||||||
|
|
||||||
terms = filterByModelIds(terms, detectModelIds);
|
List<HanlpMapResult> matches = matchStrategy.getMatches(queryContext, terms);
|
||||||
|
|
||||||
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
|
|
||||||
detectModelIds);
|
|
||||||
|
|
||||||
List<MapResult> matches = getMatches(matchResult);
|
|
||||||
|
|
||||||
HanlpHelper.transLetterOriginal(matches);
|
HanlpHelper.transLetterOriginal(matches);
|
||||||
|
|
||||||
log.info("queryContext:{},matches:{}", queryContext, matches);
|
|
||||||
|
|
||||||
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
|
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
|
||||||
for (Term term : terms) {
|
|
||||||
log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
|
||||||
}
|
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
|
||||||
terms = terms.stream().filter(term -> {
|
|
||||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
|
||||||
if (Objects.nonNull(modelId)) {
|
|
||||||
return detectModelIds.contains(modelId);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
for (Term term : terms) {
|
|
||||||
log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
|
|
||||||
term.getFrequency());
|
|
||||||
}
|
|
||||||
return terms;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
private void convertTermsToSchemaMapInfo(List<HanlpMapResult> hanlpMapResults, SchemaMapInfo schemaMap,
|
||||||
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
|
List<Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -85,8 +50,8 @@ public class HanlpDictMapper implements SchemaMapper {
|
|||||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||||
|
|
||||||
for (MapResult mapResult : mapResults) {
|
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
|
||||||
for (String nature : mapResult.getNatures()) {
|
for (String nature : hanlpMapResult.getNatures()) {
|
||||||
Long modelId = NatureHelper.getModelId(nature);
|
Long modelId = NatureHelper.getModelId(nature);
|
||||||
if (Objects.isNull(modelId)) {
|
if (Objects.isNull(modelId)) {
|
||||||
continue;
|
continue;
|
||||||
@@ -95,68 +60,27 @@ public class HanlpDictMapper implements SchemaMapper {
|
|||||||
if (Objects.isNull(elementType)) {
|
if (Objects.isNull(elementType)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
|
SchemaElement element = getSchemaElement(modelId, elementType, elementID);
|
||||||
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
|
if (element == null) {
|
||||||
|
|
||||||
BaseWordBuilder baseWordBuilder = WordBuilderFactory.get(DictWordType.getNatureType(nature));
|
|
||||||
Long elementID = baseWordBuilder.getElementID(nature);
|
|
||||||
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
|
|
||||||
|
|
||||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
|
||||||
if (Objects.isNull(elementDb)) {
|
|
||||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElement element = new SchemaElement();
|
|
||||||
BeanUtils.copyProperties(elementDb, element);
|
|
||||||
element.setAlias(getAlias(elementDb));
|
|
||||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||||
element.setName(mapResult.getName());
|
element.setName(hanlpMapResult.getName());
|
||||||
}
|
}
|
||||||
|
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
.frequency(frequency)
|
.frequency(frequency)
|
||||||
.word(mapResult.getName())
|
.word(hanlpMapResult.getName())
|
||||||
.similarity(mapResult.getSimilarity())
|
.similarity(hanlpMapResult.getSimilarity())
|
||||||
.detectWord(mapResult.getDetectWord())
|
.detectWord(hanlpMapResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
addToSchemaMap(schemaMap, modelId, schemaElementMatch);
|
||||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId,
|
|
||||||
new ArrayList<>());
|
|
||||||
if (schemaElementMatches == null) {
|
|
||||||
schemaElementMatches = modelElementMatches.get(modelId);
|
|
||||||
}
|
|
||||||
schemaElementMatches.add(schemaElementMatch);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<MapResult> getMatches(Map<MatchText, List<MapResult>> matchResult) {
|
|
||||||
List<MapResult> matches = new ArrayList<>();
|
|
||||||
if (Objects.isNull(matchResult)) {
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
Optional<List<MapResult>> first = matchResult.entrySet().stream()
|
|
||||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
|
||||||
.map(entry -> entry.getValue()).findFirst();
|
|
||||||
|
|
||||||
if (first.isPresent()) {
|
|
||||||
matches = first.get();
|
|
||||||
}
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getAlias(SchemaElement element) {
|
|
||||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
|
||||||
return element.getAlias();
|
|
||||||
}
|
|
||||||
if (CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(element.getName())) {
|
|
||||||
return element.getAlias().stream()
|
|
||||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
return element.getAlias();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,123 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.LinkedHashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* match strategy implement
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||||
|
Set<Long> detectModelIds) {
|
||||||
|
QueryReq queryReq = queryContext.getRequest();
|
||||||
|
String text = queryReq.getQueryText();
|
||||||
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||||
|
|
||||||
|
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
||||||
|
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||||
|
|
||||||
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||||
|
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||||
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex, Integer index, int offset) {
|
||||||
|
QueryReq queryReq = queryContext.getRequest();
|
||||||
|
String text = queryReq.getQueryText();
|
||||||
|
Integer agentId = queryReq.getAgentId();
|
||||||
|
String detectSegment = text.substring(startIndex, index);
|
||||||
|
|
||||||
|
// step1. pre search
|
||||||
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||||
|
agentId,
|
||||||
|
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
// step2. suffix search
|
||||||
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
||||||
|
oneDetectionMaxSize, agentId, detectModelIds).stream()
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
|
|
||||||
|
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// step3. merge pre/suffix result
|
||||||
|
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
// step4. filter by similarity
|
||||||
|
hanlpMapResults = hanlpMapResults.stream()
|
||||||
|
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||||
|
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
||||||
|
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||||
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
log.info("after isSimilarity parseResults:{}", hanlpMapResults);
|
||||||
|
|
||||||
|
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
|
||||||
|
parseResult.setOffset(offset);
|
||||||
|
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||||
|
return parseResult;
|
||||||
|
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
|
// step5. take only one dimension or 10 metric/dimension value per rond.
|
||||||
|
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream()
|
||||||
|
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||||
|
.collect(Collectors.toList())
|
||||||
|
.stream()
|
||||||
|
.limit(1)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
|
||||||
|
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||||
|
oneRoundResults = dimensionMetrics;
|
||||||
|
}
|
||||||
|
// step6. select mapResul in one round
|
||||||
|
selectResultInOneRound(existResults, oneRoundResults);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getMapKey(HanlpMapResult a) {
|
||||||
|
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||||
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
import com.tencent.supersonic.knowledge.utils.NatureHelper;
|
||||||
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -39,10 +41,14 @@ public class MapperHelper {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Integer getStepOffset(List<Integer> termList, Integer index) {
|
|
||||||
|
public Integer getStepOffset(List<Term> termList, Integer index) {
|
||||||
|
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||||
|
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||||
|
|
||||||
for (int j = 0; j < termList.size() - 1; j++) {
|
for (int j = 0; j < termList.size() - 1; j++) {
|
||||||
if (termList.get(j) <= index && termList.get(j + 1) > index) {
|
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
|
||||||
return termList.get(j);
|
return offsetList.get(j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return index;
|
return index;
|
||||||
@@ -88,7 +94,7 @@ public class MapperHelper {
|
|||||||
|
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
|
|
||||||
Set<Long> detectModelIds = agentService.getDslToolsModelIds(request.getAgentId(), null);
|
Set<Long> detectModelIds = agentService.getModelIds(request.getAgentId(), null);
|
||||||
//contains all
|
//contains all
|
||||||
if (agentService.containsAllModel(detectModelIds)) {
|
if (agentService.containsAllModel(detectModelIds)) {
|
||||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -10,8 +9,8 @@ import java.util.Set;
|
|||||||
/**
|
/**
|
||||||
* match strategy
|
* match strategy
|
||||||
*/
|
*/
|
||||||
public interface MatchStrategy {
|
public interface MatchStrategy<T> {
|
||||||
|
|
||||||
Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelId);
|
Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelId);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package com.tencent.supersonic.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public class ModelClusterMapper implements SchemaMapper {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void map(QueryContext queryContext) {
|
||||||
|
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
|
||||||
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
|
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
|
||||||
|
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
|
||||||
|
for (ModelCluster modelCluster : modelClusters) {
|
||||||
|
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||||
|
if (modelCluster.getModelIds().contains(modelId)) {
|
||||||
|
modelClusterElementMatches.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
|
||||||
|
.addAll(schemaMapInfo.getMatchedElements(modelId));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||||
|
modelClusterMapInfo.setModelElementMatches(modelClusterElementMatches);
|
||||||
|
queryContext.setModelClusterMapInfo(modelClusterMapInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
|
||||||
|
SemanticSchema semanticSchema) {
|
||||||
|
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
|
||||||
|
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||||
|
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
|
||||||
|
modelCluster.removeIf(model -> !matchedModels.contains(model));
|
||||||
|
}).filter(modelCluster -> modelCluster.size() > 0).map(ModelCluster::build).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.mapper;
|
|
||||||
|
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.LinkedHashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.apache.commons.compress.utils.Lists;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* match strategy implement
|
|
||||||
*/
|
|
||||||
@Service
|
|
||||||
@Slf4j
|
|
||||||
public class QueryMatchStrategy implements MatchStrategy {
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private MapperHelper mapperHelper;
|
|
||||||
|
|
||||||
@Autowired
|
|
||||||
private OptimizationConfig optimizationConfig;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> terms, Set<Long> detectModelIds) {
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
|
|
||||||
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
|
|
||||||
|
|
||||||
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
|
|
||||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
|
||||||
|
|
||||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelIds:{}", terms,
|
|
||||||
regOffsetToLength, offsetList, detectModelIds);
|
|
||||||
|
|
||||||
List<MapResult> detects = detect(queryReq, regOffsetToLength, offsetList, detectModelIds);
|
|
||||||
Map<MatchText, List<MapResult>> result = new HashMap<>();
|
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<MapResult> detect(QueryReq queryReq, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
|
|
||||||
Set<Long> detectModelIds) {
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
List<MapResult> results = Lists.newArrayList();
|
|
||||||
|
|
||||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
|
||||||
|
|
||||||
Set<MapResult> mapResultRowSet = new LinkedHashSet();
|
|
||||||
|
|
||||||
for (Integer i = index; i <= text.length(); ) {
|
|
||||||
int offset = mapperHelper.getStepOffset(offsetList, index);
|
|
||||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
|
||||||
if (i <= text.length()) {
|
|
||||||
List<MapResult> mapResults = detectByStep(queryReq, detectModelIds, index, i, offset);
|
|
||||||
selectMapResultInOneRound(mapResultRowSet, mapResults);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
|
||||||
results.addAll(mapResultRowSet);
|
|
||||||
}
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
|
|
||||||
for (MapResult mapResult : mapResults) {
|
|
||||||
if (mapResultRowSet.contains(mapResult)) {
|
|
||||||
boolean isDeleted = mapResultRowSet.removeIf(
|
|
||||||
entry -> {
|
|
||||||
boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
|
|
||||||
&& entry.getDetectWord().length() < mapResult.getDetectWord().length();
|
|
||||||
if (deleted) {
|
|
||||||
log.info("deleted entry:{}", entry);
|
|
||||||
}
|
|
||||||
return deleted;
|
|
||||||
}
|
|
||||||
);
|
|
||||||
if (isDeleted) {
|
|
||||||
log.info("deleted, add mapResult:{}", mapResult);
|
|
||||||
mapResultRowSet.add(mapResult);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
mapResultRowSet.add(mapResult);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getMapKey(MapResult a) {
|
|
||||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<MapResult> detectByStep(QueryReq queryReq, Set<Long> detectModelIds, Integer index, Integer i,
|
|
||||||
int offset) {
|
|
||||||
String text = queryReq.getQueryText();
|
|
||||||
Integer agentId = queryReq.getAgentId();
|
|
||||||
String detectSegment = text.substring(index, i);
|
|
||||||
|
|
||||||
// step1. pre search
|
|
||||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
|
||||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, agentId,
|
|
||||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
// step2. suffix search
|
|
||||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize,
|
|
||||||
agentId, detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
mapResults.addAll(suffixMapResults);
|
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
// step3. merge pre/suffix result
|
|
||||||
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
// step4. filter by similarity
|
|
||||||
mapResults = mapResults.stream()
|
|
||||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
|
||||||
>= mapperHelper.getThresholdMatch(term.getNatures()))
|
|
||||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
log.info("after isSimilarity parseResults:{}", mapResults);
|
|
||||||
|
|
||||||
mapResults = mapResults.stream().map(parseResult -> {
|
|
||||||
parseResult.setOffset(offset);
|
|
||||||
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
|
||||||
return parseResult;
|
|
||||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
|
||||||
|
|
||||||
// step5. take only one dimension or 10 metric/dimension value per rond.
|
|
||||||
List<MapResult> dimensionMetrics = mapResults.stream()
|
|
||||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
|
||||||
.collect(Collectors.toList())
|
|
||||||
.stream()
|
|
||||||
.limit(1)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
|
||||||
return dimensionMetrics;
|
|
||||||
} else {
|
|
||||||
return mapResults.stream().limit(optimizationConfig.getOneDetectionSize()).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,9 +2,10 @@ package com.tencent.supersonic.chat.mapper;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.hankcs.hanlp.seg.common.Term;
|
import com.hankcs.hanlp.seg.common.Term;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.MapResult;
|
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
|
||||||
import com.tencent.supersonic.knowledge.service.SearchService;
|
import com.tencent.supersonic.knowledge.service.SearchService;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -20,17 +21,16 @@ import org.springframework.stereotype.Service;
|
|||||||
* match strategy implement
|
* match strategy implement
|
||||||
*/
|
*/
|
||||||
@Service
|
@Service
|
||||||
public class SearchMatchStrategy implements MatchStrategy {
|
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||||
|
|
||||||
private static final int SEARCH_SIZE = 3;
|
private static final int SEARCH_SIZE = 3;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<MapResult>> match(QueryReq queryReq, List<Term> originals, Set<Long> detectModelIds) {
|
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
||||||
|
Set<Long> detectModelIds) {
|
||||||
|
QueryReq queryReq = queryContext.getRequest();
|
||||||
String text = queryReq.getQueryText();
|
String text = queryReq.getQueryText();
|
||||||
Map<Integer, Integer> regOffsetToLength = originals.stream()
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||||
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
|
|
||||||
.collect(Collectors.toMap(Term::getOffset, value -> value.word.length(),
|
|
||||||
(value1, value2) -> value2));
|
|
||||||
|
|
||||||
List<Integer> detectIndexList = Lists.newArrayList();
|
List<Integer> detectIndexList = Lists.newArrayList();
|
||||||
|
|
||||||
@@ -46,19 +46,19 @@ public class SearchMatchStrategy implements MatchStrategy {
|
|||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Map<MatchText, List<MapResult>> regTextMap = new ConcurrentHashMap<>();
|
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
|
||||||
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
||||||
String regText = text.substring(0, detectIndex);
|
String regText = text.substring(0, detectIndex);
|
||||||
String detectSegment = text.substring(detectIndex);
|
String detectSegment = text.substring(detectIndex);
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||||
List<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, SEARCH_SIZE,
|
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
||||||
queryReq.getAgentId(), detectModelIds);
|
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||||
mapResults.addAll(suffixMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
mapResults = mapResults.stream().filter(entry -> {
|
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||||
List<String> natures = entry.getNatures().stream()
|
List<String> natures = entry.getNatures().stream()
|
||||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
@@ -71,10 +71,27 @@ public class SearchMatchStrategy implements MatchStrategy {
|
|||||||
.regText(regText)
|
.regText(regText)
|
||||||
.detectSegment(detectSegment)
|
.detectSegment(detectSegment)
|
||||||
.build();
|
.build();
|
||||||
regTextMap.put(matchText, mapResults);
|
regTextMap.put(matchText, hanlpMapResults);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
return regTextMap;
|
return regTextMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey(HanlpMapResult a) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||||
|
Integer startIndex,
|
||||||
|
Integer i, int offset) {
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSON;
|
||||||
|
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.net.URL;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class HttpLLMInterpreter implements LLMInterpreter {
|
||||||
|
|
||||||
|
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
|
||||||
|
try {
|
||||||
|
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||||
|
|
||||||
|
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||||
|
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||||
|
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||||
|
LLMResp.class);
|
||||||
|
|
||||||
|
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||||
|
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
|
||||||
|
return responseEntity.getBody();
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("requestLLM error", e);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||||
|
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||||
|
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
|
||||||
|
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||||
|
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||||
|
try {
|
||||||
|
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
||||||
|
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||||
|
FunctionResp.class);
|
||||||
|
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
||||||
|
System.currentTimeMillis() - startTime);
|
||||||
|
return responseEntity.getBody();
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("requestFunction error", e);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Unified interpreter for invoking the llm layer.
|
||||||
|
*/
|
||||||
|
public interface LLMInterpreter {
|
||||||
|
|
||||||
|
|
||||||
|
/**
 * Asks the llm layer to turn the request into SQL.
 *
 * @param llmReq request payload
 * @param modelClusterKey key of the model cluster the request belongs to
 * @return the llm response; NOTE(review): the HTTP implementation added in this change
 *         returns {@code null} on failure, so callers should be prepared for that
 */
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||||
|
|
||||||
|
/**
 * Sends a function-call request to the llm layer.
 *
 * @param functionReq request payload
 * @return the function-call response; NOTE(review): the HTTP implementation added in this
 *         change returns {@code null} when the exchange fails
 */
FunctionResp requestFunction(FunctionReq functionReq);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.common.pojo.QueryType;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Query type parser, determine if the query is a metric query, an entity query,
|
||||||
|
* or another type of query.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class QueryTypeParser implements SemanticParser {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
|
||||||
|
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||||
|
User user = queryContext.getRequest().getUser();
|
||||||
|
|
||||||
|
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||||
|
// 1.init S2SQL
|
||||||
|
semanticQuery.initS2Sql(user);
|
||||||
|
// 2.set queryType
|
||||||
|
QueryType queryType = getQueryType(semanticQuery);
|
||||||
|
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private QueryType getQueryType(SemanticQuery semanticQuery) {
|
||||||
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
|
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||||
|
return QueryType.OTHER;
|
||||||
|
}
|
||||||
|
//1. entity queryType
|
||||||
|
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||||
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof S2SQLQuery) {
|
||||||
|
//If all the fields in the SELECT statement are of tag type.
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||||
|
if (CollectionUtils.isNotEmpty(selectFields)) {
|
||||||
|
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
|
||||||
|
return QueryType.TAG;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//2. metric queryType
|
||||||
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||||
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
|
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
||||||
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
|
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||||
|
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||||
|
if (containMetric) {
|
||||||
|
return QueryType.METRIC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return QueryType.OTHER;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ public class SatisfactionChecker {
|
|||||||
// check all the parse info in candidate
|
// check all the parse info in candidate
|
||||||
public static boolean check(QueryContext queryContext) {
|
public static boolean check(QueryContext queryContext) {
|
||||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||||
if (query.getQueryMode().equals(DslQuery.QUERY_MODE)) {
|
if (query.getQueryMode().equals(S2SQLQuery.QUERY_MODE)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
/**
 * Value holder for one DSL parse round: the LLM request/response pair together with the
 * originating query request and the DSL tool that was used.
 */
@Data
|
|
||||||
@Builder
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class DSLParseResult {
|
|
||||||
|
|
||||||
// request sent to the LLM
private LLMReq llmReq;
|
|
||||||
|
|
||||||
// response received from the LLM
private LLMResp llmResp;
|
|
||||||
|
|
||||||
// query request that triggered the parse
private QueryReq request;
|
|
||||||
|
|
||||||
// DSL tool used for this parse
private DslTool dslTool;
|
|
||||||
}
|
|
||||||
@@ -1,231 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Map.Entry;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class HeuristicModelResolver implements ModelResolver {
|
|
||||||
|
|
||||||
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> modelQueryModes,
|
|
||||||
SchemaMapInfo schemaMap) {
|
|
||||||
//model count priority
|
|
||||||
Long modelIdByModelCount = getModelIdByModelCount(schemaMap);
|
|
||||||
if (Objects.nonNull(modelIdByModelCount)) {
|
|
||||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
|
||||||
return modelIdByModelCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
|
||||||
if (modelTypeMap.size() == 1) {
|
|
||||||
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
|
||||||
if (modelQueryModes.containsKey(modelSelect)) {
|
|
||||||
log.info("selectModel with only one Model [{}]", modelSelect);
|
|
||||||
return modelSelect;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
|
||||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
|
||||||
.sorted((o1, o2) -> {
|
|
||||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
|
||||||
if (difference == 0) {
|
|
||||||
return (int) ((o2.getValue().getMaxSimilarity()
|
|
||||||
- o1.getValue().getMaxSimilarity()) * 100);
|
|
||||||
}
|
|
||||||
return difference;
|
|
||||||
}).findFirst().orElse(null);
|
|
||||||
if (maxModel != null) {
|
|
||||||
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
|
||||||
return maxModel.getKey();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0L;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Long getModelIdByModelCount(SchemaMapInfo schemaMap) {
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
|
||||||
Map<Long, Integer> modelIdToModelCount = new HashMap<>();
|
|
||||||
if (Objects.nonNull(modelElementMatches)) {
|
|
||||||
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
|
||||||
Long modelId = modelElementMatch.getKey();
|
|
||||||
List<SchemaElementMatch> modelMatches = modelElementMatch.getValue().stream().filter(
|
|
||||||
elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType())
|
|
||||||
).collect(Collectors.toList());
|
|
||||||
|
|
||||||
if (!CollectionUtils.isEmpty(modelMatches)) {
|
|
||||||
Integer count = modelMatches.size();
|
|
||||||
modelIdToModelCount.put(modelId, count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Entry<Long, Integer> maxModelCount = modelIdToModelCount.entrySet().stream()
|
|
||||||
.max(Comparator.comparingInt(o -> o.getValue())).orElse(null);
|
|
||||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelCount, modelIdToModelCount);
|
|
||||||
if (Objects.nonNull(maxModelCount)) {
|
|
||||||
return maxModelCount.getKey();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* to check can switch Model if context exit Model
|
|
||||||
*
|
|
||||||
* @return false will use context Model, true will use other Model , maybe include context Model
|
|
||||||
*/
|
|
||||||
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> modelQueryModes, SchemaMapInfo schemaMap,
|
|
||||||
ChatContext chatCtx, QueryReq searchCtx,
|
|
||||||
Long modelId, Set<Long> restrictiveModels) {
|
|
||||||
if (!Objects.nonNull(modelId) || modelId <= 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// except content Model, calculate the number of types for each Model, if numbers<=1 will not switch
|
|
||||||
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
|
||||||
log.info("isAllowSwitch ModelTypeMap [{}]", modelTypeMap);
|
|
||||||
long otherModelTypeNumBigOneCount = modelTypeMap.entrySet().stream()
|
|
||||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
|
|
||||||
.filter(entry -> entry.getValue().getCount() > 1).count();
|
|
||||||
if (otherModelTypeNumBigOneCount >= 1) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// if query text only contain time , will not switch
|
|
||||||
if (!CollectionUtils.isEmpty(modelQueryModes.values())) {
|
|
||||||
for (SemanticQuery semanticQuery : modelQueryModes.values()) {
|
|
||||||
if (semanticQuery == null) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
|
|
||||||
if (semanticParseInfo == null) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
|
|
||||||
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
|
|
||||||
if (semanticParseInfo.getDateInfo().getDetectWord()
|
|
||||||
.equalsIgnoreCase(searchCtx.getQueryText())) {
|
|
||||||
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
|
|
||||||
semanticParseInfo.getDateInfo());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (CollectionUtils.isNotEmpty(restrictiveModels) && !restrictiveModels.contains(modelId)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// if context Model not in schemaMap , will switch
|
|
||||||
if (schemaMap.getMatchedElements(modelId) == null || schemaMap.getMatchedElements(modelId).size() <= 0) {
|
|
||||||
log.info("modelId not in schemaMap ");
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
// other will not switch
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
|
|
||||||
Map<Long, ModelMatchResult> modelCount = new HashMap<>();
|
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
|
||||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
|
||||||
if (!modelCount.containsKey(entry.getKey())) {
|
|
||||||
modelCount.put(entry.getKey(), new ModelMatchResult());
|
|
||||||
}
|
|
||||||
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
|
||||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
|
||||||
schemaElementMatches.stream()
|
|
||||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
|
||||||
schemaElementMatch.getElement().getType()));
|
|
||||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
|
||||||
.sorted((o1, o2) ->
|
|
||||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
|
||||||
).findFirst().orElse(null);
|
|
||||||
if (schemaElementMatchMax != null) {
|
|
||||||
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
|
||||||
}
|
|
||||||
modelMatchResult.setCount(schemaElementTypes.size());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return modelCount;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
|
||||||
Long modelId = queryContext.getRequest().getModelId();
|
|
||||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
|
||||||
if (CollectionUtils.isEmpty(restrictiveModels)) {
|
|
||||||
return modelId;
|
|
||||||
}
|
|
||||||
if (restrictiveModels.contains(modelId)) {
|
|
||||||
return modelId;
|
|
||||||
} else {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
|
||||||
Set<Long> matchedModels = mapInfo.getMatchedModels();
|
|
||||||
if (CollectionUtils.isNotEmpty(restrictiveModels)) {
|
|
||||||
matchedModels = matchedModels.stream()
|
|
||||||
.filter(restrictiveModels::contains)
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
}
|
|
||||||
Map<Long, SemanticQuery> modelQueryModes = new HashMap<>();
|
|
||||||
for (Long matchedModel : matchedModels) {
|
|
||||||
modelQueryModes.put(matchedModel, null);
|
|
||||||
}
|
|
||||||
if (modelQueryModes.size() == 1) {
|
|
||||||
return modelQueryModes.keySet().stream().findFirst().get();
|
|
||||||
}
|
|
||||||
return resolve(modelQueryModes, queryContext, chatCtx,
|
|
||||||
queryContext.getMapInfo(), restrictiveModels);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Long resolve(Map<Long, SemanticQuery> modelQueryModes, QueryContext queryContext,
|
|
||||||
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
|
||||||
Long selectModel = selectModel(modelQueryModes, queryContext.getRequest(),
|
|
||||||
chatCtx, schemaMap, restrictiveModels);
|
|
||||||
if (selectModel > 0) {
|
|
||||||
log.info("selectModel {} ", selectModel);
|
|
||||||
return selectModel;
|
|
||||||
}
|
|
||||||
// get the max SchemaElementType number
|
|
||||||
return selectModelBySchemaElementCount(modelQueryModes, schemaMap);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
|
|
||||||
ChatContext chatCtx,
|
|
||||||
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
|
||||||
// if QueryContext has modelId and in ModelQueryModes
|
|
||||||
if (modelQueryModes.containsKey(queryContext.getModelId())) {
|
|
||||||
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
|
|
||||||
return queryContext.getModelId();
|
|
||||||
}
|
|
||||||
// if ChatContext has modelId and in ModelQueryModes
|
|
||||||
if (chatCtx.getParseInfo().getModelId() > 0) {
|
|
||||||
Long modelId = chatCtx.getParseInfo().getModelId();
|
|
||||||
if (!isAllowSwitch(modelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
|
|
||||||
log.info("selectModel from ChatContext [{}]", modelId);
|
|
||||||
return modelId;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// default 0
|
|
||||||
return 0L;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,460 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
|
||||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
|
||||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
|
||||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
|
|
||||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.function.Function;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.http.HttpEntity;
|
|
||||||
import org.springframework.http.HttpHeaders;
|
|
||||||
import org.springframework.http.HttpMethod;
|
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.http.ResponseEntity;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
import org.springframework.web.client.RestTemplate;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class LLMDslParser implements SemanticParser {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
|
||||||
QueryReq request = queryCtx.getRequest();
|
|
||||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
|
||||||
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
|
||||||
log.info("llm parser url is empty, skip dsl parser, llmParserConfig:{}", llmParserConfig);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (SatisfactionChecker.check(queryCtx)) {
|
|
||||||
log.info("skip dsl parser, queryText:{}", request.getQueryText());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
Long modelId = getModelId(queryCtx, chatCtx, request.getAgentId());
|
|
||||||
if (Objects.isNull(modelId) || modelId <= 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
DslTool dslTool = getDslTool(request, modelId);
|
|
||||||
if (Objects.isNull(dslTool)) {
|
|
||||||
log.info("no dsl tool in this agent, skip dsl parser");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig);
|
|
||||||
LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig);
|
|
||||||
|
|
||||||
if (Objects.isNull(llmResp)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
DSLParseResult dslParseResult = DSLParseResult.builder().request(request)
|
|
||||||
.dslTool(dslTool).llmReq(llmReq).llmResp(llmResp).build();
|
|
||||||
|
|
||||||
SemanticParseInfo parseInfo = getParseInfo(queryCtx, modelId, dslTool, dslParseResult);
|
|
||||||
|
|
||||||
SemanticCorrectInfo semanticCorrectInfo = getCorrectorSql(queryCtx, parseInfo, llmResp.getSqlOutput());
|
|
||||||
|
|
||||||
llmResp.setCorrectorSql(semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
updateParseInfo(semanticCorrectInfo, modelId, parseInfo);
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("LLMDSLParser error", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
|
|
||||||
return elements.stream()
|
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
|
|
||||||
&& allFields.contains(schemaElement.getName())
|
|
||||||
).collect(Collectors.toSet());
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<String> getFieldsExceptDate(List<String> allFields) {
|
|
||||||
if (CollectionUtils.isEmpty(allFields)) {
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
return allFields.stream()
|
|
||||||
.filter(entry -> !DateUtils.DATE_FIELD.equalsIgnoreCase(entry))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
public void updateParseInfo(SemanticCorrectInfo semanticCorrectInfo, Long modelId, SemanticParseInfo parseInfo) {
|
|
||||||
|
|
||||||
String correctorSql = semanticCorrectInfo.getSql();
|
|
||||||
parseInfo.getSqlInfo().setLogicSql(correctorSql);
|
|
||||||
|
|
||||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
|
|
||||||
//set dataInfo
|
|
||||||
try {
|
|
||||||
if (!CollectionUtils.isEmpty(expressions)) {
|
|
||||||
DateConf dateInfo = getDateInfo(expressions);
|
|
||||||
parseInfo.setDateInfo(dateInfo);
|
|
||||||
}
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("set dateInfo error :", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
//set filter
|
|
||||||
try {
|
|
||||||
Map<String, SchemaElement> fieldNameToElement = getNameToElement(modelId);
|
|
||||||
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
|
|
||||||
parseInfo.getDimensionFilters().addAll(result);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("set dimensionFilter error :", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
|
|
||||||
if (Objects.isNull(semanticSchema)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(semanticCorrectInfo.getSql()));
|
|
||||||
|
|
||||||
Set<SchemaElement> metrics = getElements(modelId, allFields, semanticSchema.getMetrics());
|
|
||||||
parseInfo.setMetrics(metrics);
|
|
||||||
|
|
||||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(semanticCorrectInfo.getSql())) {
|
|
||||||
parseInfo.setNativeQuery(false);
|
|
||||||
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(semanticCorrectInfo.getSql());
|
|
||||||
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
|
|
||||||
parseInfo.setDimensions(getElements(modelId, groupByDimensions, semanticSchema.getDimensions()));
|
|
||||||
} else {
|
|
||||||
parseInfo.setNativeQuery(true);
|
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(semanticCorrectInfo.getSql());
|
|
||||||
List<String> selectDimensions = getFieldsExceptDate(selectFields);
|
|
||||||
parseInfo.setDimensions(getElements(modelId, selectDimensions, semanticSchema.getDimensions()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
|
||||||
List<FilterExpression> filterExpressions) {
|
|
||||||
List<QueryFilter> result = Lists.newArrayList();
|
|
||||||
for (FilterExpression expression : filterExpressions) {
|
|
||||||
QueryFilter dimensionFilter = new QueryFilter();
|
|
||||||
dimensionFilter.setValue(expression.getFieldValue());
|
|
||||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
|
||||||
if (Objects.isNull(schemaElement)) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
dimensionFilter.setName(schemaElement.getName());
|
|
||||||
dimensionFilter.setBizName(schemaElement.getBizName());
|
|
||||||
dimensionFilter.setElementID(schemaElement.getId());
|
|
||||||
|
|
||||||
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
|
|
||||||
dimensionFilter.setOperator(operatorEnum);
|
|
||||||
dimensionFilter.setFunction(expression.getFunction());
|
|
||||||
result.add(dimensionFilter);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
|
||||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
|
||||||
.filter(expression -> DateUtils.DATE_FIELD.equalsIgnoreCase(expression.getFieldName()))
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
|
||||||
return new DateConf();
|
|
||||||
}
|
|
||||||
DateConf dateInfo = new DateConf();
|
|
||||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
|
||||||
FilterExpression firstExpression = dateExpressions.get(0);
|
|
||||||
|
|
||||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
|
||||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
|
||||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
|
||||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
|
||||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
|
||||||
return dateInfo;
|
|
||||||
}
|
|
||||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
|
|
||||||
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
|
|
||||||
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
|
|
||||||
if (hasSecondDate(dateExpressions)) {
|
|
||||||
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
|
|
||||||
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
|
|
||||||
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
|
|
||||||
if (hasSecondDate(dateExpressions)) {
|
|
||||||
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return dateInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
|
||||||
FilterOperatorEnum... operatorEnums) {
|
|
||||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
|
||||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
|
||||||
}
|
|
||||||
|
|
||||||
private SemanticCorrectInfo getCorrectorSql(QueryContext queryCtx, SemanticParseInfo parseInfo, String sql) {
|
|
||||||
|
|
||||||
SemanticCorrectInfo correctInfo = SemanticCorrectInfo.builder()
|
|
||||||
.queryFilters(queryCtx.getRequest().getQueryFilters()).sql(sql)
|
|
||||||
.parseInfo(parseInfo).build();
|
|
||||||
|
|
||||||
List<SemanticCorrector> dslCorrections = ComponentFactory.getSqlCorrections();
|
|
||||||
|
|
||||||
dslCorrections.forEach(dslCorrection -> {
|
|
||||||
try {
|
|
||||||
dslCorrection.correct(correctInfo);
|
|
||||||
log.info("sqlCorrection:{} sql:{}", dslCorrection.getClass().getSimpleName(), correctInfo.getSql());
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error(String.format("correct error,correctInfo:%s", correctInfo), e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
return correctInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
private SemanticParseInfo getParseInfo(QueryContext queryCtx, Long modelId, DslTool dslTool,
|
|
||||||
DSLParseResult dslParseResult) {
|
|
||||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DslQuery.QUERY_MODE);
|
|
||||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
|
||||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
|
|
||||||
|
|
||||||
Map<String, Object> properties = new HashMap<>();
|
|
||||||
properties.put(Constants.CONTEXT, dslParseResult);
|
|
||||||
properties.put("type", "internal");
|
|
||||||
properties.put("name", dslTool.getName());
|
|
||||||
|
|
||||||
parseInfo.setProperties(properties);
|
|
||||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length());
|
|
||||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
|
||||||
parseInfo.getSqlInfo().setLlmParseSql(dslParseResult.getLlmResp().getSqlOutput());
|
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
|
||||||
|
|
||||||
SchemaElement model = new SchemaElement();
|
|
||||||
model.setModel(modelId);
|
|
||||||
model.setId(modelId);
|
|
||||||
model.setName(modelIdToName.get(modelId));
|
|
||||||
parseInfo.setModel(model);
|
|
||||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
|
||||||
return parseInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
private DslTool getDslTool(QueryReq request, Long modelId) {
|
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
|
||||||
List<DslTool> dslTools = agentService.getDslTools(request.getAgentId(), AgentToolType.DSL);
|
|
||||||
Optional<DslTool> dslToolOptional = dslTools.stream()
|
|
||||||
.filter(tool -> {
|
|
||||||
List<Long> modelIds = tool.getModelIds();
|
|
||||||
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return modelIds.contains(modelId);
|
|
||||||
})
|
|
||||||
.findFirst();
|
|
||||||
return dslToolOptional.orElse(null);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
|
||||||
Set<Long> distinctModelIds = agentService.getDslToolsModelIds(agentId, AgentToolType.DSL);
|
|
||||||
if (agentService.containsAllModel(distinctModelIds)) {
|
|
||||||
distinctModelIds = new HashSet<>();
|
|
||||||
}
|
|
||||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
|
||||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
|
||||||
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
|
|
||||||
return modelId;
|
|
||||||
}
|
|
||||||
|
|
||||||
private LLMResp requestLLM(LLMReq llmReq, Long modelId, LLMParserConfig llmParserConfig) {
|
|
||||||
String questUrl = llmParserConfig.getUrl() + llmParserConfig.getQueryToSqlPath();
|
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
|
|
||||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
|
||||||
try {
|
|
||||||
HttpHeaders headers = new HttpHeaders();
|
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
|
||||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
|
||||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity,
|
|
||||||
LLMResp.class);
|
|
||||||
|
|
||||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
|
||||||
System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody());
|
|
||||||
return responseEntity.getBody();
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("requestLLM error", e);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private LLMReq getLlmReq(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
|
||||||
String queryText = queryCtx.getRequest().getQueryText();
|
|
||||||
|
|
||||||
LLMReq llmReq = new LLMReq();
|
|
||||||
llmReq.setQueryText(queryText);
|
|
||||||
|
|
||||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
|
||||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
|
||||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
|
||||||
|
|
||||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig);
|
|
||||||
|
|
||||||
fieldNameList.add(DateUtils.DATE_FIELD);
|
|
||||||
llmSchema.setFieldNameList(fieldNameList);
|
|
||||||
llmReq.setSchema(llmSchema);
|
|
||||||
|
|
||||||
List<ElementValue> linking = new ArrayList<>();
|
|
||||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
|
||||||
llmReq.setLinking(linking);
|
|
||||||
|
|
||||||
String currentDate = DSLDateHelper.getReferenceDate(modelId);
|
|
||||||
llmReq.setCurrentDate(currentDate);
|
|
||||||
return llmReq;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
|
||||||
|
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
|
||||||
return new ArrayList<>();
|
|
||||||
}
|
|
||||||
Set<ElementValue> valueMatches = matchedElements
|
|
||||||
.stream()
|
|
||||||
.filter(elementMatch -> !elementMatch.isInherited())
|
|
||||||
.filter(schemaElementMatch -> {
|
|
||||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
|
||||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
|
||||||
})
|
|
||||||
.map(elementMatch -> {
|
|
||||||
ElementValue elementValue = new ElementValue();
|
|
||||||
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
|
|
||||||
elementValue.setFieldValue(elementMatch.getWord());
|
|
||||||
return elementValue;
|
|
||||||
}).collect(Collectors.toSet());
|
|
||||||
return new ArrayList<>(valueMatches);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions();
|
|
||||||
List<SchemaElement> metrics = semanticSchema.getMetrics();
|
|
||||||
|
|
||||||
List<SchemaElement> allElements = Lists.newArrayList();
|
|
||||||
allElements.addAll(dimensions);
|
|
||||||
allElements.addAll(metrics);
|
|
||||||
return allElements.stream()
|
|
||||||
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
|
|
||||||
.collect(Collectors.toMap(SchemaElement::getName, Function.identity(), (value1, value2) -> value2));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
|
||||||
LLMParserConfig llmParserConfig) {
|
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
|
|
||||||
|
|
||||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
|
|
||||||
|
|
||||||
results.addAll(fieldNameList);
|
|
||||||
return new ArrayList<>(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
|
||||||
return new HashSet<>();
|
|
||||||
}
|
|
||||||
Set<String> fieldNameList = matchedElements.stream()
|
|
||||||
.filter(schemaElementMatch -> {
|
|
||||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
|
||||||
return SchemaElementType.METRIC.equals(elementType)
|
|
||||||
|| SchemaElementType.DIMENSION.equals(elementType)
|
|
||||||
|| SchemaElementType.VALUE.equals(elementType);
|
|
||||||
})
|
|
||||||
.map(schemaElementMatch -> {
|
|
||||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
|
||||||
|
|
||||||
if (!SchemaElementType.VALUE.equals(elementType)) {
|
|
||||||
return schemaElementMatch.getWord();
|
|
||||||
}
|
|
||||||
return itemIdToName.get(schemaElementMatch.getElement().getId());
|
|
||||||
})
|
|
||||||
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
return fieldNameList;
|
|
||||||
}
|
|
||||||
|
|
||||||
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
|
|
||||||
LLMParserConfig llmParserConfig) {
|
|
||||||
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
|
||||||
.limit(llmParserConfig.getDimensionTopN())
|
|
||||||
.map(entry -> entry.getName())
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
|
||||||
.limit(llmParserConfig.getMetricTopN())
|
|
||||||
.map(entry -> entry.getName())
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
|
|
||||||
results.addAll(metrics);
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
|
||||||
return semanticSchema.getDimensions(modelId).stream()
|
|
||||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -5,34 +5,32 @@ import com.google.common.collect.Sets;
|
|||||||
import com.tencent.supersonic.chat.agent.Agent;
|
import com.tencent.supersonic.chat.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||||
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
|
||||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -71,7 +69,7 @@ public class MetricInterpretParser implements SemanticParser {
|
|||||||
|
|
||||||
private void buildQuery(Long modelId, QueryContext queryContext,
|
private void buildQuery(Long modelId, QueryContext queryContext,
|
||||||
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||||
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
|
LLMSemanticQuery metricInterpretQuery = QueryManager.createLLMQuery(MetricInterpretQuery.QUERY_MODE);
|
||||||
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
||||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
||||||
metrics, schemaElementMatches, toolName);
|
metrics, schemaElementMatches, toolName);
|
||||||
@@ -82,9 +80,8 @@ public class MetricInterpretParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
|
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
|
||||||
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true);
|
List<SchemaElement> metrics = semanticService.getSemanticSchema().getMetrics();
|
||||||
Set<SchemaElement> metrics = modelSchema.getMetrics();
|
|
||||||
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
|
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
@@ -113,16 +110,13 @@ public class MetricInterpretParser implements SemanticParser {
|
|||||||
|
|
||||||
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
|
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
|
||||||
List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||||
SchemaElement model = new SchemaElement();
|
|
||||||
model.setModel(modelId);
|
|
||||||
model.setId(modelId);
|
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setMetrics(metrics);
|
semanticParseInfo.setMetrics(metrics);
|
||||||
SchemaElement dimension = new SchemaElement();
|
SchemaElement dimension = new SchemaElement();
|
||||||
dimension.setBizName(TimeDimensionEnum.DAY.getName());
|
dimension.setBizName(TimeDimensionEnum.DAY.getName());
|
||||||
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
|
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
|
||||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
semanticParseInfo.setModel(model);
|
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||||
semanticParseInfo.setScore(queryReq.getQueryText().length());
|
semanticParseInfo.setScore(queryReq.getQueryText().length());
|
||||||
DateConf dateConf = new DateConf();
|
DateConf dateConf = new DateConf();
|
||||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||||
|
|||||||
@@ -0,0 +1,148 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class HeuristicModelResolver implements ModelResolver {
|
||||||
|
|
||||||
|
protected static String selectModelBySchemaElementMatchScore(Map<String, SemanticQuery> modelQueryModes,
|
||||||
|
SchemaModelClusterMapInfo schemaMap) {
|
||||||
|
//model count priority
|
||||||
|
String modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
||||||
|
if (Objects.nonNull(modelIdByModelCount)) {
|
||||||
|
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||||
|
return modelIdByModelCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||||
|
if (modelTypeMap.size() == 1) {
|
||||||
|
String modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||||
|
if (modelQueryModes.containsKey(modelSelect)) {
|
||||||
|
log.info("selectModel with only one Model [{}]", modelSelect);
|
||||||
|
return modelSelect;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Map.Entry<String, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||||
|
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||||
|
.sorted((o1, o2) -> {
|
||||||
|
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||||
|
if (difference == 0) {
|
||||||
|
return (int) ((o2.getValue().getMaxSimilarity()
|
||||||
|
- o1.getValue().getMaxSimilarity()) * 100);
|
||||||
|
}
|
||||||
|
return difference;
|
||||||
|
}).findFirst().orElse(null);
|
||||||
|
if (maxModel != null) {
|
||||||
|
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
||||||
|
return maxModel.getKey();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String getModelIdByMatchModelScore(SchemaModelClusterMapInfo schemaMap) {
|
||||||
|
Map<String, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||||
|
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||||
|
Map<String, Double> modelIdToModelScore = new HashMap<>();
|
||||||
|
if (Objects.nonNull(modelElementMatches)) {
|
||||||
|
for (Entry<String, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||||
|
String modelId = modelElementMatch.getKey();
|
||||||
|
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
|
||||||
|
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||||
|
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
||||||
|
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
||||||
|
// get sum of model match score
|
||||||
|
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||||
|
modelIdToModelScore.put(modelId, score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Entry<String, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
||||||
|
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
||||||
|
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
||||||
|
if (Objects.nonNull(maxModelScore)) {
|
||||||
|
return maxModelScore.getKey();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
|
||||||
|
Map<String, ModelMatchResult> modelCount = new HashMap<>();
|
||||||
|
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
||||||
|
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||||
|
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||||
|
if (!modelCount.containsKey(entry.getKey())) {
|
||||||
|
modelCount.put(entry.getKey(), new ModelMatchResult());
|
||||||
|
}
|
||||||
|
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
||||||
|
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||||
|
schemaElementMatches.stream()
|
||||||
|
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||||
|
schemaElementMatch.getElement().getType()));
|
||||||
|
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||||
|
.sorted((o1, o2) ->
|
||||||
|
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||||
|
).findFirst().orElse(null);
|
||||||
|
if (schemaElementMatchMax != null) {
|
||||||
|
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||||
|
}
|
||||||
|
modelMatchResult.setCount(schemaElementTypes.size());
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modelCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||||
|
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
||||||
|
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
||||||
|
Long modelId = queryContext.getRequest().getModelId();
|
||||||
|
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||||
|
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
|
||||||
|
return getModelClusterByModelId(modelId, matchedModelClusters);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, SemanticQuery> modelQueryModes = new HashMap<>();
|
||||||
|
for (String matchedModel : matchedModelClusters) {
|
||||||
|
modelQueryModes.put(matchedModel, null);
|
||||||
|
}
|
||||||
|
if (modelQueryModes.size() == 1) {
|
||||||
|
return modelQueryModes.keySet().stream().findFirst().get();
|
||||||
|
}
|
||||||
|
return selectModelBySchemaElementMatchScore(modelQueryModes, mapInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getModelClusterByModelId(Long modelId, Set<String> modelClusterKeySet) {
|
||||||
|
for (String modelClusterKey : modelClusterKeySet) {
|
||||||
|
if (ModelCluster.build(modelClusterKey).getModelIds().contains(modelId)) {
|
||||||
|
return modelClusterKey;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,271 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||||
|
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||||
|
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class LLMRequestService {
|
||||||
|
|
||||||
|
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||||
|
|
||||||
|
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||||
|
@Autowired
|
||||||
|
private LLMParserConfig llmParserConfig;
|
||||||
|
@Autowired
|
||||||
|
private AgentService agentService;
|
||||||
|
@Autowired
|
||||||
|
private SchemaService schemaService;
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
|
||||||
|
public boolean check(QueryContext queryCtx) {
|
||||||
|
QueryReq request = queryCtx.getRequest();
|
||||||
|
if (StringUtils.isEmpty(llmParserConfig.getUrl())) {
|
||||||
|
log.info("llm parser url is empty, skip {} , llmParserConfig:{}", LLMS2SQLParser.class, llmParserConfig);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (SatisfactionChecker.check(queryCtx)) {
|
||||||
|
log.info("skip {}, queryText:{}", LLMS2SQLParser.class, request.getQueryText());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
|
||||||
|
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
|
||||||
|
if (agentService.containsAllModel(distinctModelIds)) {
|
||||||
|
distinctModelIds = new HashSet<>();
|
||||||
|
}
|
||||||
|
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||||
|
String modelCluster = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||||
|
log.info("resolve modelId:{},llmParser Models:{}", modelCluster, distinctModelIds);
|
||||||
|
return ModelCluster.build(modelCluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
public CommonAgentTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
|
||||||
|
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
|
||||||
|
AgentToolType.LLM_S2SQL);
|
||||||
|
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
|
||||||
|
.filter(tool -> {
|
||||||
|
List<Long> modelIds = tool.getModelIds();
|
||||||
|
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
for (Long modelId : modelIdSet) {
|
||||||
|
if (modelIds.contains(modelId)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
})
|
||||||
|
.findFirst();
|
||||||
|
return llmParserTool.orElse(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||||
|
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||||
|
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||||
|
String queryText = queryCtx.getRequest().getQueryText();
|
||||||
|
|
||||||
|
LLMReq llmReq = new LLMReq();
|
||||||
|
llmReq.setQueryText(queryText);
|
||||||
|
Long firstModelId = modelCluster.getFirstModel();
|
||||||
|
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||||
|
llmReq.setFilterCondition(filterCondition);
|
||||||
|
|
||||||
|
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||||
|
llmSchema.setModelName(modelIdToName.get(firstModelId));
|
||||||
|
llmSchema.setDomainName(modelIdToName.get(firstModelId));
|
||||||
|
|
||||||
|
List<String> fieldNameList = getFieldNameList(queryCtx, modelCluster, llmParserConfig);
|
||||||
|
|
||||||
|
String priorExts = getPriorExts(modelCluster.getModelIds(), fieldNameList);
|
||||||
|
llmReq.setPriorExts(priorExts);
|
||||||
|
|
||||||
|
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||||
|
llmSchema.setFieldNameList(fieldNameList);
|
||||||
|
llmReq.setSchema(llmSchema);
|
||||||
|
|
||||||
|
List<ElementValue> linking = new ArrayList<>();
|
||||||
|
if (optimizationConfig.isUseLinkingValueSwitch()) {
|
||||||
|
linking.addAll(linkingValues);
|
||||||
|
}
|
||||||
|
llmReq.setLinking(linking);
|
||||||
|
|
||||||
|
String currentDate = S2SQLDateHelper.getReferenceDate(firstModelId);
|
||||||
|
if (StringUtils.isEmpty(currentDate)) {
|
||||||
|
currentDate = DateUtils.getBeforeDate(0);
|
||||||
|
}
|
||||||
|
llmReq.setCurrentDate(currentDate);
|
||||||
|
return llmReq;
|
||||||
|
}
|
||||||
|
|
||||||
|
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
return llmInterpreter.query2sql(llmReq, modelClusterKey);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||||
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
|
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
|
||||||
|
|
||||||
|
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
|
||||||
|
|
||||||
|
results.addAll(fieldNameList);
|
||||||
|
return new ArrayList<>(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getPriorExts(Set<Long> modelIds, List<String> fieldNameList) {
|
||||||
|
StringBuilder extraInfoSb = new StringBuilder();
|
||||||
|
List<ModelSchemaResp> modelSchemaResps = semanticInterpreter.fetchModelSchema(
|
||||||
|
new ArrayList<>(modelIds), true);
|
||||||
|
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||||
|
|
||||||
|
ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0);
|
||||||
|
Map<String, String> fieldNameToDataFormatType = modelSchemaResp.getMetrics()
|
||||||
|
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||||
|
.flatMap(metricSchemaResp -> {
|
||||||
|
Set<Pair<String, String>> result = new HashSet<>();
|
||||||
|
String dataFormatType = metricSchemaResp.getDataFormatType();
|
||||||
|
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
|
||||||
|
List<String> aliasList = SchemaItem.getAliasList(metricSchemaResp.getAlias());
|
||||||
|
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||||
|
for (String alias : aliasList) {
|
||||||
|
result.add(Pair.of(alias, dataFormatType));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.stream();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toMap(a -> a.getLeft(), a -> a.getRight(), (k1, k2) -> k1));
|
||||||
|
|
||||||
|
for (String fieldName : fieldNameList) {
|
||||||
|
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
|
||||||
|
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
||||||
|
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
||||||
|
String format = String.format("%s的计量单位是%s", fieldName, "小数; ");
|
||||||
|
extraInfoSb.append(format);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return extraInfoSb.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||||
|
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||||
|
|
||||||
|
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||||
|
.getMatchedElements(modelCluster.getKey());
|
||||||
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
Set<ElementValue> valueMatches = matchedElements
|
||||||
|
.stream()
|
||||||
|
.filter(elementMatch -> !elementMatch.isInherited())
|
||||||
|
.filter(schemaElementMatch -> {
|
||||||
|
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||||
|
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||||
|
})
|
||||||
|
.map(elementMatch -> {
|
||||||
|
ElementValue elementValue = new ElementValue();
|
||||||
|
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
|
||||||
|
elementValue.setFieldValue(elementMatch.getWord());
|
||||||
|
return elementValue;
|
||||||
|
}).collect(Collectors.toSet());
|
||||||
|
return new ArrayList<>(valueMatches);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Map<Long, String> getItemIdToName(ModelCluster modelCluster) {
|
||||||
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
|
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||||
|
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private Set<String> getTopNFieldNames(ModelCluster modelCluster, LLMParserConfig llmParserConfig) {
|
||||||
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
|
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||||
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
|
.limit(llmParserConfig.getDimensionTopN())
|
||||||
|
.map(entry -> entry.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
Set<String> metrics = semanticSchema.getMetrics(modelCluster.getModelIds()).stream()
|
||||||
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
|
.limit(llmParserConfig.getMetricTopN())
|
||||||
|
.map(entry -> entry.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
results.addAll(metrics);
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||||
|
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
|
||||||
|
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||||
|
.getMatchedElements(modelCluster.getKey());
|
||||||
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
|
return new HashSet<>();
|
||||||
|
}
|
||||||
|
Set<String> fieldNameList = matchedElements.stream()
|
||||||
|
.filter(schemaElementMatch -> {
|
||||||
|
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||||
|
return SchemaElementType.METRIC.equals(elementType)
|
||||||
|
|| SchemaElementType.DIMENSION.equals(elementType)
|
||||||
|
|| SchemaElementType.VALUE.equals(elementType);
|
||||||
|
})
|
||||||
|
.map(schemaElementMatch -> {
|
||||||
|
SchemaElement element = schemaElementMatch.getElement();
|
||||||
|
SchemaElementType elementType = element.getType();
|
||||||
|
if (SchemaElementType.VALUE.equals(elementType)) {
|
||||||
|
return itemIdToName.get(element.getId());
|
||||||
|
}
|
||||||
|
return schemaElementMatch.getWord();
|
||||||
|
})
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
return fieldNameList;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class LLMResponseService {
|
||||||
|
|
||||||
|
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||||
|
if (Objects.isNull(weight)) {
|
||||||
|
weight = 0D;
|
||||||
|
}
|
||||||
|
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
|
||||||
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
|
parseInfo.setModel(parseResult.getModelCluster());
|
||||||
|
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||||
|
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
|
||||||
|
.getMatchedElements(parseInfo.getModelClusterKey()));
|
||||||
|
|
||||||
|
Map<String, Object> properties = new HashMap<>();
|
||||||
|
properties.put(Constants.CONTEXT, parseResult);
|
||||||
|
properties.put("type", "internal");
|
||||||
|
properties.put("name", commonAgentTool.getName());
|
||||||
|
|
||||||
|
parseInfo.setProperties(properties);
|
||||||
|
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||||
|
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||||
|
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||||
|
parseInfo.setModel(parseResult.getModelCluster());
|
||||||
|
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||||
|
return parseInfo;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
|
||||||
|
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
|
||||||
|
return llmResp.getSqlWeight();
|
||||||
|
}
|
||||||
|
Map<String, Double> result = new HashMap<>();
|
||||||
|
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
|
||||||
|
String key = entry.getKey();
|
||||||
|
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
result.put(key, entry.getValue());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class LLMS2SQLParser implements SemanticParser {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||||
|
QueryReq request = queryCtx.getRequest();
|
||||||
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
//1.determine whether to skip this parser.
|
||||||
|
if (requestService.check(queryCtx)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
//2.get modelId from queryCtx and chatCtx.
|
||||||
|
ModelCluster modelCluster = requestService.getModelCluster(queryCtx, chatCtx, request.getAgentId());
|
||||||
|
if (StringUtils.isBlank(modelCluster.getKey())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
//3.get agent tool and determine whether to skip this parser.
|
||||||
|
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
|
||||||
|
if (Objects.isNull(commonAgentTool)) {
|
||||||
|
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||||
|
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelCluster);
|
||||||
|
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||||
|
LLMReq llmReq = requestService.getLlmReq(queryCtx, semanticSchema, modelCluster, linkingValues);
|
||||||
|
LLMResp llmResp = requestService.requestLLM(llmReq, modelCluster.getKey());
|
||||||
|
|
||||||
|
if (Objects.isNull(llmResp)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
//5. deduplicate the SQL result list and build parserInfo
|
||||||
|
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||||
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
|
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
|
||||||
|
ParseResult parseResult = ParseResult.builder()
|
||||||
|
.request(request)
|
||||||
|
.modelCluster(modelCluster)
|
||||||
|
.commonAgentTool(commonAgentTool)
|
||||||
|
.llmReq(llmReq)
|
||||||
|
.llmResp(llmResp)
|
||||||
|
.linkingValues(linkingValues)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
|
||||||
|
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||||
|
} else {
|
||||||
|
deduplicationSqlWeight.forEach((sql, weight) -> {
|
||||||
|
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("parse", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
@@ -7,6 +7,6 @@ import java.util.Set;
|
|||||||
|
|
||||||
public interface ModelResolver {
|
public interface ModelResolver {
|
||||||
|
|
||||||
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor
|
||||||
|
public class ParseResult {
|
||||||
|
|
||||||
|
private ModelCluster modelCluster;
|
||||||
|
|
||||||
|
private LLMReq llmReq;
|
||||||
|
|
||||||
|
private LLMResp llmResp;
|
||||||
|
|
||||||
|
private QueryReq request;
|
||||||
|
|
||||||
|
private CommonAgentTool commonAgentTool;
|
||||||
|
|
||||||
|
private List<ElementValue> linkingValues;
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||||
@@ -10,39 +10,35 @@ import com.tencent.supersonic.common.util.DateUtils;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
public class DSLDateHelper {
|
public class S2SQLDateHelper {
|
||||||
|
|
||||||
public static String getReferenceDate(Long modelId) {
|
public static String getReferenceDate(Long modelId) {
|
||||||
String chatDetailDate = getChatDetailDate(modelId);
|
String defaultDate = DateUtils.getBeforeDate(0);
|
||||||
if (StringUtils.isNotBlank(chatDetailDate)) {
|
|
||||||
return chatDetailDate;
|
|
||||||
}
|
|
||||||
return DateUtils.getBeforeDate(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static String getChatDetailDate(Long modelId) {
|
|
||||||
if (Objects.isNull(modelId)) {
|
if (Objects.isNull(modelId)) {
|
||||||
return null;
|
return defaultDate;
|
||||||
}
|
}
|
||||||
ChatConfigFilter filter = new ChatConfigFilter();
|
ChatConfigFilter filter = new ChatConfigFilter();
|
||||||
filter.setModelId(modelId);
|
filter.setModelId(modelId);
|
||||||
|
|
||||||
List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
||||||
if (CollectionUtils.isEmpty(configResps)) {
|
if (CollectionUtils.isEmpty(configResps)) {
|
||||||
return null;
|
return defaultDate;
|
||||||
}
|
}
|
||||||
ChatConfigResp chatConfigResp = configResps.get(0);
|
ChatConfigResp chatConfigResp = configResps.get(0);
|
||||||
if (Objects.isNull(chatConfigResp.getChatDetailConfig()) || Objects.isNull(
|
if (Objects.isNull(chatConfigResp.getChatDetailConfig()) || Objects.isNull(
|
||||||
chatConfigResp.getChatDetailConfig().getChatDefaultConfig())) {
|
chatConfigResp.getChatDetailConfig().getChatDefaultConfig())) {
|
||||||
return null;
|
return defaultDate;
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatDefaultConfigReq chatDefaultConfig = chatConfigResp.getChatDetailConfig().getChatDefaultConfig();
|
ChatDefaultConfigReq chatDefaultConfig = chatConfigResp.getChatDetailConfig().getChatDefaultConfig();
|
||||||
Integer unit = chatDefaultConfig.getUnit();
|
Integer unit = chatDefaultConfig.getUnit();
|
||||||
String period = chatDefaultConfig.getPeriod();
|
String period = chatDefaultConfig.getPeriod();
|
||||||
if (Objects.nonNull(unit)) {
|
if (Objects.nonNull(unit)) {
|
||||||
|
// If the unit is set to less than 0, then do not add relative date.
|
||||||
|
if (unit < 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
|
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(period);
|
||||||
if (Objects.isNull(datePeriodEnum)) {
|
if (Objects.isNull(datePeriodEnum)) {
|
||||||
return DateUtils.getBeforeDate(unit);
|
return DateUtils.getBeforeDate(unit);
|
||||||
@@ -50,6 +46,7 @@ public class DSLDateHelper {
|
|||||||
return DateUtils.getBeforeDate(unit, datePeriodEnum);
|
return DateUtils.getBeforeDate(unit, datePeriodEnum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return defaultDate;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.parser.llm.time;
|
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
|
||||||
import com.tencent.supersonic.common.util.ChatGptHelper;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class LLMTimeEnhancementParse implements SemanticParser {
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
|
||||||
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
|
|
||||||
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
|
|
||||||
try {
|
|
||||||
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
|
|
||||||
if (!queryContext.getCandidateQueries().isEmpty()) {
|
|
||||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
|
||||||
DateConf dateInfo = query.getParseInfo().getDateInfo();
|
|
||||||
JSONObject jsonObject = JSON.parseObject(inferredTime);
|
|
||||||
if (jsonObject.containsKey("date")) {
|
|
||||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
|
||||||
dateInfo.setStartDate(jsonObject.getString("date"));
|
|
||||||
dateInfo.setEndDate(jsonObject.getString("date"));
|
|
||||||
query.getParseInfo().setDateInfo(dateInfo);
|
|
||||||
} else if (jsonObject.containsKey("start")) {
|
|
||||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
|
||||||
dateInfo.setStartDate(jsonObject.getString("start"));
|
|
||||||
dateInfo.setEndDate(jsonObject.getString("end"));
|
|
||||||
query.getParseInfo().setDateInfo(dateInfo);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (Exception exception) {
|
|
||||||
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
|
|
||||||
(Object) exception.getStackTrace());
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info("{} after queryContext:{},chatContext:{}",
|
|
||||||
LLMTimeEnhancementParse.class.getSimpleName(), queryContext, chatContext);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -3,14 +3,14 @@ package com.tencent.supersonic.chat.parser.plugin;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||||
@@ -18,8 +18,10 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
|||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -29,6 +31,12 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||||
|
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||||
|
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
if (!checkPreCondition(queryContext)) {
|
if (!checkPreCondition(queryContext)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -51,8 +59,10 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
for (Long modelId : modelIds) {
|
for (Long modelId : modelIds) {
|
||||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
||||||
queryContext.getMapInfo().getMatchedElements(modelId), pluginRecallResult.getDistance());
|
queryContext.getRequest(),
|
||||||
|
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
|
||||||
|
pluginRecallResult.getDistance());
|
||||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||||
pluginQuery.setParseInfo(semanticParseInfo);
|
pluginQuery.setParseInfo(semanticParseInfo);
|
||||||
@@ -72,12 +82,9 @@ public abstract class PluginParser implements SemanticParser {
|
|||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
SchemaElement model = new SchemaElement();
|
|
||||||
model.setModel(modelId);
|
|
||||||
model.setId(modelId);
|
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
semanticParseInfo.setModel(model);
|
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||||
pluginParseResult.setPlugin(plugin);
|
pluginParseResult.setPlugin(plugin);
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||||
|
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||||
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -22,23 +25,16 @@ import org.springframework.util.CollectionUtils;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingBasedParser extends PluginParser {
|
public class EmbeddingBasedParser extends PluginParser {
|
||||||
|
|
||||||
|
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean checkPreCondition(QueryContext queryContext) {
|
public boolean checkPreCondition(QueryContext queryContext) {
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
|
if (StringUtils.isBlank(embeddingConfig.getUrl()) && llmInterpreter instanceof HttpLLMInterpreter) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
List<Plugin> plugins = getPluginList(queryContext);
|
List<Plugin> plugins = getPluginList(queryContext);
|
||||||
if (CollectionUtils.isEmpty(plugins)) {
|
return !CollectionUtils.isEmpty(plugins);
|
||||||
return false;
|
|
||||||
}
|
|
||||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
|
||||||
for (SemanticQuery semanticQuery : semanticQueries) {
|
|
||||||
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -1,44 +1,39 @@
|
|||||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSON;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||||
|
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
|
||||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||||
import com.tencent.supersonic.chat.service.PluginService;
|
import com.tencent.supersonic.chat.service.PluginService;
|
||||||
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import java.net.URI;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.http.HttpEntity;
|
|
||||||
import org.springframework.http.HttpHeaders;
|
|
||||||
import org.springframework.http.HttpMethod;
|
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.http.ResponseEntity;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import org.springframework.web.client.RestTemplate;
|
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FunctionBasedParser extends PluginParser {
|
public class FunctionBasedParser extends PluginParser {
|
||||||
|
|
||||||
|
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean checkPreCondition(QueryContext queryContext) {
|
public boolean checkPreCondition(QueryContext queryContext) {
|
||||||
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||||
String functionUrl = functionCallConfig.getUrl();
|
String functionUrl = functionCallConfig.getUrl();
|
||||||
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryContext)) {
|
if (StringUtils.isBlank(functionUrl) && llmInterpreter instanceof HttpLLMInterpreter) {
|
||||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||||
queryContext.getRequest().getQueryText());
|
queryContext.getRequest().getQueryText());
|
||||||
return false;
|
return false;
|
||||||
@@ -89,7 +84,7 @@ public class FunctionBasedParser extends PluginParser {
|
|||||||
FunctionReq functionReq = FunctionReq.builder()
|
FunctionReq functionReq = FunctionReq.builder()
|
||||||
.queryText(queryContext.getRequest().getQueryText())
|
.queryText(queryContext.getRequest().getQueryText())
|
||||||
.pluginConfigs(pluginToFunctionCall).build();
|
.pluginConfigs(pluginToFunctionCall).build();
|
||||||
functionResp = requestFunction(functionReq);
|
functionResp = llmInterpreter.requestFunction(functionReq);
|
||||||
}
|
}
|
||||||
return functionResp;
|
return functionResp;
|
||||||
}
|
}
|
||||||
@@ -102,7 +97,7 @@ public class FunctionBasedParser extends PluginParser {
|
|||||||
log.info("user decide Model:{}", modelId);
|
log.info("user decide Model:{}", modelId);
|
||||||
List<Plugin> plugins = getPluginList(queryContext);
|
List<Plugin> plugins = getPluginList(queryContext);
|
||||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||||
if (DslQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
if (S2SQLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (plugin.getParseModeConfig() == null) {
|
if (plugin.getParseModeConfig() == null) {
|
||||||
@@ -132,25 +127,4 @@ public class FunctionBasedParser extends PluginParser {
|
|||||||
return functionDOList;
|
return functionDOList;
|
||||||
}
|
}
|
||||||
|
|
||||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
|
||||||
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
|
||||||
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
|
|
||||||
HttpHeaders headers = new HttpHeaders();
|
|
||||||
long startTime = System.currentTimeMillis();
|
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
|
||||||
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
|
|
||||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
|
||||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
|
||||||
try {
|
|
||||||
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
|
|
||||||
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
|
||||||
FunctionResp.class);
|
|
||||||
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
|
|
||||||
System.currentTimeMillis() - startTime);
|
|
||||||
return responseEntity.getBody();
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("requestFunction error", e);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,10 +9,14 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
|
|||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
|
import com.tencent.supersonic.common.pojo.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -40,13 +44,26 @@ public class AgentCheckParser implements SemanticParser {
|
|||||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||||
queries.removeIf(query -> {
|
queries.removeIf(query -> {
|
||||||
for (RuleQueryTool tool : queryTools) {
|
for (RuleQueryTool tool : queryTools) {
|
||||||
if (!tool.getQueryModes().contains(query.getQueryMode())) {
|
if (CollectionUtils.isNotEmpty(tool.getQueryModes())
|
||||||
|
&& !tool.getQueryModes().contains(query.getQueryMode())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
|
||||||
|
if (QueryManager.isTagQuery(query.getQueryMode())) {
|
||||||
|
return !tool.getQueryTypes().contains(QueryType.TAG.name());
|
||||||
|
}
|
||||||
|
if (QueryManager.isMetricQuery(query.getQueryMode())) {
|
||||||
|
return !tool.getQueryTypes().contains(QueryType.METRIC.name());
|
||||||
|
}
|
||||||
|
}
|
||||||
if (CollectionUtils.isEmpty(tool.getModelIds())) {
|
if (CollectionUtils.isEmpty(tool.getModelIds())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (tool.isContainsAllModel() || tool.getModelIds().contains(query.getParseInfo().getModelId())) {
|
if (tool.isContainsAllModel()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (new HashSet<>(tool.getModelIds())
|
||||||
|
.containsAll(query.getParseInfo().getModel().getModelIds())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,12 @@ import lombok.AllArgsConstructor;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AggregateTypeParser extracts aggregation type specified in the user query
|
||||||
|
* based on keyword matching.
|
||||||
|
* Currently, it supports 7 types of aggregation: max, min, sum, avg, topN,
|
||||||
|
* distinct count, count.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class AggregateTypeParser implements SemanticParser {
|
public class AggregateTypeParser implements SemanticParser {
|
||||||
|
|
||||||
|
|||||||
@@ -1,33 +1,47 @@
|
|||||||
package com.tencent.supersonic.chat.parser.rule;
|
package com.tencent.supersonic.chat.parser.rule;
|
||||||
|
|
||||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
|
|
||||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
|
||||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
|
||||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
|
||||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
|
|
||||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
|
||||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||||
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
|
||||||
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
|
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
|
||||||
|
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
import java.util.AbstractMap;
|
import java.util.AbstractMap;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import java.util.stream.Stream;
|
import java.util.stream.Stream;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.TAG;
|
||||||
|
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* ContextInheritParser tries to inherit certain schema elements from context
|
||||||
|
* so that in multi-turn conversations users don't need to mention some keyword
|
||||||
|
* repeatedly.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ContextInheritParser implements SemanticParser {
|
public class ContextInheritParser implements SemanticParser {
|
||||||
|
|
||||||
@@ -36,20 +50,22 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
|
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
|
||||||
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
|
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
|
||||||
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
|
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
|
||||||
|
new AbstractMap.SimpleEntry<>(TAG, Arrays.asList(TAG)),
|
||||||
new AbstractMap.SimpleEntry<>(MODEL, Arrays.asList(MODEL)),
|
new AbstractMap.SimpleEntry<>(MODEL, Arrays.asList(MODEL)),
|
||||||
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
|
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
|
||||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
if (!shouldInherit(queryContext, chatContext)) {
|
if (!shouldInherit(queryContext)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
ModelCluster modelCluster = getMatchedModelCluster(queryContext, chatContext);
|
||||||
Long modelId = chatContext.getParseInfo().getModelId();
|
if (modelCluster == null) {
|
||||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo()
|
return;
|
||||||
.getMatchedElements(modelId);
|
}
|
||||||
|
List<SchemaElementMatch> elementMatches = queryContext.getModelClusterMapInfo()
|
||||||
|
.getMatchedElements(modelCluster.getKey());
|
||||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||||
SchemaElementType matchType = match.getElement().getType();
|
SchemaElementType matchType = match.getElement().getType();
|
||||||
@@ -64,18 +80,18 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
|
|
||||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||||
for (RuleSemanticQuery query : queries) {
|
for (RuleSemanticQuery query : queries) {
|
||||||
query.fillParseInfo(modelId, queryContext, chatContext);
|
query.fillParseInfo(chatContext);
|
||||||
if (existSameQuery(query.getParseInfo().getModelId(), query.getQueryMode(), queryContext)) {
|
if (existSameQuery(query.getParseInfo().getModelClusterKey(), query.getQueryMode(), queryContext)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
queryContext.getCandidateQueries().add(query);
|
queryContext.getCandidateQueries().add(query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean existSameQuery(Long modelId, String queryMode, QueryContext queryContext) {
|
private boolean existSameQuery(String modelClusterKey, String queryMode, QueryContext queryContext) {
|
||||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||||
&& semanticQuery.getParseInfo().getModelId().equals(modelId)) {
|
&& semanticQuery.getParseInfo().getModelClusterKey().equals(modelClusterKey)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,30 +105,41 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
return matches.stream().anyMatch(m -> {
|
return matches.stream().anyMatch(m -> {
|
||||||
SchemaElementType type = m.getElement().getType();
|
SchemaElementType type = m.getElement().getType();
|
||||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||||
&& !(ruleQuery instanceof MetricEntityQuery)) {
|
&& !(ruleQuery instanceof MetricTagQuery)) {
|
||||||
return types.contains(type);
|
return types.contains(type);
|
||||||
}
|
}
|
||||||
return type.equals(matchType);
|
return type.equals(matchType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
|
protected boolean shouldInherit(QueryContext queryContext) {
|
||||||
Long contextModelId = chatContext.getParseInfo().getModelId();
|
|
||||||
// if map info doesn't contain the same Model of the context,
|
|
||||||
// no inheritance could be done
|
|
||||||
if (queryContext.getMapInfo().getMatchedElements(contextModelId) == null) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// if candidates only have MetricModel mode, count in context
|
// if candidates only have MetricModel mode, count in context
|
||||||
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
|
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
|
||||||
.filter(query -> query instanceof MetricModelQuery).collect(
|
.filter(query -> query instanceof MetricModelQuery).collect(
|
||||||
Collectors.toList());
|
Collectors.toList());
|
||||||
if (metricModelQueries.size() == queryContext.getCandidateQueries().size()) {
|
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||||
return true;
|
}
|
||||||
} else {
|
|
||||||
return queryContext.getCandidateQueries().size() == 0;
|
protected ModelCluster getMatchedModelCluster(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
String contextModelClusterKey = chatContext.getParseInfo().getModelClusterKey();
|
||||||
|
if (StringUtils.isBlank(contextModelClusterKey)) {
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
|
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||||
|
List<ModelCluster> allModelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||||
|
Set<String> queryModelClusters = queryContext.getModelClusterMapInfo().getMatchedModelClusters();
|
||||||
|
ModelCluster contextModelCluster = ModelCluster.build(contextModelClusterKey);
|
||||||
|
for (String cluster : queryModelClusters) {
|
||||||
|
ModelCluster queryModelCluster = ModelCluster.build(cluster);
|
||||||
|
for (ModelCluster modelCluster : allModelClusters) {
|
||||||
|
if (modelCluster.getModelIds().containsAll(contextModelCluster.getModelIds())
|
||||||
|
&& modelCluster.getModelIds().containsAll(queryModelCluster.getModelIds())) {
|
||||||
|
return queryModelCluster;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,42 +1,34 @@
|
|||||||
package com.tencent.supersonic.chat.parser.rule;
|
package com.tencent.supersonic.chat.parser.rule;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||||
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* QueryModeParser resolves a specific query mode according to co-appearance
|
||||||
|
* of certain schema element types.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class QueryModeParser implements SemanticParser {
|
public class QueryModeParser implements SemanticParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
SchemaModelClusterMapInfo modelClusterMapInfo = queryContext.getModelClusterMapInfo();
|
||||||
// iterate all schemaElementMatches to resolve semantic query
|
// iterate all schemaElementMatches to resolve query mode
|
||||||
for (Long modelId : mapInfo.getMatchedModels()) {
|
for (String modelClusterKey : modelClusterMapInfo.getMatchedModelClusters()) {
|
||||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> elementMatches = modelClusterMapInfo.getMatchedElements(modelClusterKey);
|
||||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||||
for (RuleSemanticQuery query : queries) {
|
for (RuleSemanticQuery query : queries) {
|
||||||
query.fillParseInfo(modelId, queryContext, chatContext);
|
query.fillParseInfo(chatContext);
|
||||||
queryContext.getCandidateQueries().add(query);
|
queryContext.getCandidateQueries().add(query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if modelElementMatches id empty,so remove it.
|
|
||||||
Map<Long, List<SchemaElementMatch>> filterModelElementMatches = new HashMap<>();
|
|
||||||
for (Long modelId : queryContext.getMapInfo().getModelElementMatches().keySet()) {
|
|
||||||
if (!CollectionUtils.isEmpty(queryContext.getMapInfo().getModelElementMatches().get(modelId))) {
|
|
||||||
filterModelElementMatches.put(modelId, queryContext.getMapInfo().getModelElementMatches().get(modelId));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
queryContext.getMapInfo().setModelElementMatches(filterModelElementMatches);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.rule;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* RuleBasedParser acts as a container that incorporates a group of
|
||||||
|
* rule-based semantic parsers.
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class RuleBasedParser implements SemanticParser {
|
||||||
|
|
||||||
|
private static List<SemanticParser> ruleParsers = Arrays.asList(
|
||||||
|
new QueryModeParser(),
|
||||||
|
new ContextInheritParser(),
|
||||||
|
new AgentCheckParser(),
|
||||||
|
new TimeRangeParser(),
|
||||||
|
new AggregateTypeParser()
|
||||||
|
);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
ruleParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,6 +24,13 @@ import com.xkzhangsan.time.nlp.TimeNLPUtil;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* TimeRangeParser extracts time range specified in the user query
|
||||||
|
* based on keyword matching.
|
||||||
|
* Currently, it supports two kinds of expression:
|
||||||
|
* 1. Recent unit: 近N天/周/月/年、过去N天/周/月/年
|
||||||
|
* 2. Concrete date: 2023年11月15日、20231115
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TimeRangeParser implements SemanticParser {
|
public class TimeRangeParser implements SemanticParser {
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ public class ChatConfigDO {
|
|||||||
|
|
||||||
private Integer status;
|
private Integer status;
|
||||||
|
|
||||||
|
private String llmExamples;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* record info
|
* record info
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ package com.tencent.supersonic.chat.persistence.dataobject;
|
|||||||
public enum CostType {
|
public enum CostType {
|
||||||
MAPPER(1, "mapper"),
|
MAPPER(1, "mapper"),
|
||||||
PARSER(2, "parser"),
|
PARSER(2, "parser"),
|
||||||
QUERY(3, "query");
|
QUERY(3, "query"),
|
||||||
|
PARSERRESPONDER(4, "responder"),
|
||||||
|
POSTPROCESSOR(5, "postprocessor");
|
||||||
|
|
||||||
private Integer type;
|
private Integer type;
|
||||||
private String name;
|
private String name;
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ public interface ChatParseMapper {
|
|||||||
|
|
||||||
boolean batchSaveParseInfo(@Param("list") List<ChatParseDO> list);
|
boolean batchSaveParseInfo(@Param("list") List<ChatParseDO> list);
|
||||||
|
|
||||||
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
|
boolean updateParseInfo(ChatParseDO chatParseDO);
|
||||||
|
|
||||||
|
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||||
|
|
||||||
|
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,6 @@ import java.util.List;
|
|||||||
@Mapper
|
@Mapper
|
||||||
public interface ShowCaseCustomMapper {
|
public interface ShowCaseCustomMapper {
|
||||||
|
|
||||||
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId);
|
List<ChatQueryDO> queryShowCase(int start, int limit, int agentId, String userName);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,18 +21,21 @@ public interface ChatQueryRepository {
|
|||||||
|
|
||||||
void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
|
void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
|
||||||
|
|
||||||
|
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
|
||||||
|
|
||||||
ChatQueryDO getLastChatQuery(long chatId);
|
ChatQueryDO getLastChatQuery(long chatId);
|
||||||
|
|
||||||
int updateChatQuery(ChatQueryDO chatQueryDO);
|
int updateChatQuery(ChatQueryDO chatQueryDO);
|
||||||
|
|
||||||
Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq);
|
Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq);
|
||||||
|
|
||||||
Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||||
ParseResp parseResult,
|
ParseResp parseResult,
|
||||||
List<SemanticParseInfo> candidateParses,
|
List<SemanticParseInfo> candidateParses);
|
||||||
List<SemanticParseInfo> selectedParses);
|
|
||||||
|
|
||||||
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
|
public ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||||
|
|
||||||
|
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||||
|
|
||||||
Boolean deleteChatQuery(Long questionId);
|
Boolean deleteChatQuery(Long questionId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,5 +7,5 @@ import java.util.List;
|
|||||||
|
|
||||||
public interface StatisticsRepository {
|
public interface StatisticsRepository {
|
||||||
|
|
||||||
boolean batchSaveStatistics(List<StatisticsDO> list);
|
void batchSaveStatistics(List<StatisticsDO> list);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,33 +4,31 @@ import com.github.pagehelper.PageHelper;
|
|||||||
import com.github.pagehelper.PageInfo;
|
import com.github.pagehelper.PageInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
|
||||||
import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper;
|
import com.tencent.supersonic.chat.persistence.mapper.ChatParseMapper;
|
||||||
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
|
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
|
||||||
import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper;
|
import com.tencent.supersonic.chat.persistence.mapper.custom.ShowCaseCustomMapper;
|
||||||
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.PageUtils;
|
import com.tencent.supersonic.common.util.PageUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Comparator;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Repository
|
@Repository
|
||||||
@Primary
|
@Primary
|
||||||
@@ -78,9 +76,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId) {
|
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
|
||||||
return showCaseCustomMapper.queryShowCase(pageQueryInfoCommend.getCurrent(),
|
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
|
||||||
pageQueryInfoCommend.getPageSize(), agentId).stream().map(this::convertTo)
|
pageQueryInfoReq.getPageSize(), agentId, pageQueryInfoReq.getUserName())
|
||||||
|
.stream().map(this::convertTo)
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,30 +128,37 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
return queryId;
|
return queryId;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Boolean batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
@Override
|
||||||
|
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||||
ParseResp parseResult,
|
ParseResp parseResult,
|
||||||
List<SemanticParseInfo> candidateParses,
|
List<SemanticParseInfo> candidateParses) {
|
||||||
List<SemanticParseInfo> selectedParses) {
|
|
||||||
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
|
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
|
||||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||||
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
|
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
|
||||||
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
|
chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||||
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
|
return chatParseDOList;
|
||||||
Boolean save = chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
|
||||||
return save;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
|
@Override
|
||||||
|
public void updateChatParseInfo(List<ChatParseDO> chatParseDOS) {
|
||||||
|
for (ChatParseDO chatParseDO : chatParseDOS) {
|
||||||
|
chatParseMapper.updateParseInfo(chatParseDO);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
|
||||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||||
for (int i = 0; i < parses.size(); i++) {
|
for (int i = 0; i < parses.size(); i++) {
|
||||||
ChatParseDO chatParseDO = new ChatParseDO();
|
ChatParseDO chatParseDO = new ChatParseDO();
|
||||||
parses.get(i).setId(base + i + 1);
|
|
||||||
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
||||||
chatParseDO.setQuestionId(queryId);
|
chatParseDO.setQuestionId(queryId);
|
||||||
chatParseDO.setQueryText(queryReq.getQueryText());
|
chatParseDO.setQueryText(queryReq.getQueryText());
|
||||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||||
chatParseDO.setIsCandidate(isCandidate);
|
chatParseDO.setIsCandidate(1);
|
||||||
chatParseDO.setParseId(base + i + 1);
|
if (i == 0) {
|
||||||
|
chatParseDO.setIsCandidate(0);
|
||||||
|
}
|
||||||
|
chatParseDO.setParseId(parses.get(i).getId());
|
||||||
chatParseDO.setCreateTime(new java.util.Date());
|
chatParseDO.setCreateTime(new java.util.Date());
|
||||||
chatParseDO.setUserName(queryReq.getUser().getName());
|
chatParseDO.setUserName(queryReq.getUser().getName());
|
||||||
chatParseDOList.add(chatParseDO);
|
chatParseDOList.add(chatParseDO);
|
||||||
@@ -179,8 +185,14 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
|
return chatQueryDOMapper.updateByPrimaryKeyWithBLOBs(chatQueryDO);
|
||||||
}
|
}
|
||||||
|
|
||||||
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId) {
|
|
||||||
return chatParseMapper.getParseInfo(questionId, userName, parseId);
|
public ChatParseDO getParseInfo(Long questionId, int parseId) {
|
||||||
|
return chatParseMapper.getParseInfo(questionId, parseId);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<ChatParseDO> getParseInfoList(List<Long> questionIds) {
|
||||||
|
return chatParseMapper.getParseInfoList(questionIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -20,10 +20,9 @@ public class StatisticsRepositoryImpl implements StatisticsRepository {
|
|||||||
this.statisticsMapper = statisticsMapper;
|
this.statisticsMapper = statisticsMapper;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean batchSaveStatistics(List<StatisticsDO> list) {
|
public void batchSaveStatistics(List<StatisticsDO> list) {
|
||||||
return statisticsMapper.batchSaveStatistics(list);
|
statisticsMapper.batchSaveStatistics(list);
|
||||||
}
|
}
|
||||||
|
|
||||||
;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user