mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement](supersonic) based on version 0.7.2 (#34)
Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
@@ -10,8 +10,6 @@ import lombok.ToString;
|
||||
@ToString
|
||||
public class QueryAuthResReq {
|
||||
|
||||
private String user;
|
||||
|
||||
private List<String> departmentIds = new ArrayList<>();
|
||||
|
||||
private List<AuthRes> resources;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.auth.api.authorization.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
||||
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
||||
@@ -14,5 +15,5 @@ public interface AuthService {
|
||||
|
||||
void removeAuthGroup(AuthGroup group);
|
||||
|
||||
AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, HttpServletRequest request);
|
||||
AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user);
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.auth.authentication.config;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
public class TppConfig {
|
||||
|
||||
@Value(value = "${auth.app.secret:}")
|
||||
private String appSecret;
|
||||
|
||||
@Value(value = "${auth.app.key:}")
|
||||
private String appKey;
|
||||
|
||||
@Value(value = "${auth.oa.url:}")
|
||||
private String tppOaUrl;
|
||||
|
||||
}
|
||||
@@ -2,27 +2,24 @@ package com.tencent.supersonic.auth.authorization.application;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
||||
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
||||
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -78,12 +75,12 @@ public class AuthServiceImpl implements AuthService {
|
||||
|
||||
|
||||
@Override
|
||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, HttpServletRequest request) {
|
||||
Set<String> userOrgIds = userService.getUserAllOrgId(req.getUser());
|
||||
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
|
||||
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
|
||||
if (!CollectionUtils.isEmpty(userOrgIds)) {
|
||||
req.setDepartmentIds(new ArrayList<>(userOrgIds));
|
||||
}
|
||||
List<AuthGroup> groups = getAuthGroups(req);
|
||||
List<AuthGroup> groups = getAuthGroups(req, user.getName());
|
||||
AuthorizedResourceResp resource = new AuthorizedResourceResp();
|
||||
Map<String, List<AuthGroup>> authGroupsByModelId = groups.stream()
|
||||
.collect(Collectors.groupingBy(AuthGroup::getModelId));
|
||||
@@ -130,14 +127,14 @@ public class AuthServiceImpl implements AuthService {
|
||||
return resource;
|
||||
}
|
||||
|
||||
private List<AuthGroup> getAuthGroups(QueryAuthResReq req) {
|
||||
private List<AuthGroup> getAuthGroups(QueryAuthResReq req, String userName) {
|
||||
List<AuthGroup> groups = load().stream()
|
||||
.filter(group -> {
|
||||
if (!Objects.equals(group.getModelId(), req.getModelId())) {
|
||||
return false;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
|
||||
.contains(req.getUser())) {
|
||||
.contains(userName)) {
|
||||
return true;
|
||||
}
|
||||
for (String departmentId : req.getDepartmentIds()) {
|
||||
@@ -148,7 +145,7 @@ public class AuthServiceImpl implements AuthService {
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("user:{} department:{} authGroups:{}", req.getUser(), req.getDepartmentIds(), groups);
|
||||
log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups);
|
||||
return groups;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package com.tencent.supersonic.auth.authorization.rest;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
|
||||
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
|
||||
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import java.util.List;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
@@ -62,12 +65,13 @@ public class AuthController {
|
||||
* 查询有权限访问的受限资源id
|
||||
*
|
||||
* @param req
|
||||
* @param request
|
||||
* @return
|
||||
*/
|
||||
@PostMapping("/queryAuthorizedRes")
|
||||
public AuthorizedResourceResp queryAuthorizedResources(@RequestBody QueryAuthResReq req,
|
||||
HttpServletRequest request) {
|
||||
return authService.queryAuthorizedResources(req, request);
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return authService.queryAuthorizedResources(req, user);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.component;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
|
||||
public interface DSLOptimizer {
|
||||
CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Builder
|
||||
public class CorrectionInfo {
|
||||
|
||||
private QueryFilters queryFilters;
|
||||
|
||||
private SemanticParseInfo parseInfo;
|
||||
|
||||
private String sql;
|
||||
|
||||
}
|
||||
@@ -12,4 +12,5 @@ public class QueryReq {
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
}
|
||||
|
||||
@@ -40,12 +40,6 @@
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.plexpt</groupId>
|
||||
<artifactId>chatgpt</artifactId>
|
||||
<version>4.1.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class Agent extends RecordInfo {
|
||||
|
||||
private Integer id;
|
||||
private Integer enableSearch;
|
||||
private String name;
|
||||
private String description;
|
||||
|
||||
//0 offline, 1 online
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<Map> toolList = (List) map.get("tools");
|
||||
return toolList.stream()
|
||||
.filter(tool -> type.name().equals(tool.get("type")))
|
||||
.map(JSONObject::toJSONString)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public boolean enableSearch() {
|
||||
return enableSearch != null && enableSearch == 1;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentTool;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class AgentConfig {
|
||||
|
||||
List<AgentTool> tools = Lists.newArrayList();
|
||||
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class AgentTool {
|
||||
|
||||
private String name;
|
||||
|
||||
private AgentToolType type;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
public enum AgentToolType {
|
||||
RULE,
|
||||
DSL,
|
||||
PLUGIN,
|
||||
INTERPRET
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class DslTool extends AgentTool {
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
public class MetricInterpretTool extends AgentTool {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private List<MetricOption> metricOptions;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class PluginTool extends AgentTool {
|
||||
|
||||
private List<Long> plugins;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RuleQueryTool extends AgentTool {
|
||||
|
||||
private List<String> queryModes;
|
||||
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import org.springframework.stereotype.Service;
|
||||
@Slf4j
|
||||
public class MapperHelper {
|
||||
|
||||
@Value("${one.detection.size:6}")
|
||||
@Value("${one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
@Value("${one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
@@ -64,7 +64,7 @@ public class MapperHelper {
|
||||
*/
|
||||
public boolean existDimensionValues(List<String> natures) {
|
||||
for (String nature : natures) {
|
||||
if (NatureHelper.isDimensionValueClassId(nature)) {
|
||||
if (NatureHelper.isDimensionValueModelId(nature)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectmodelId) {
|
||||
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId) {
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
@@ -43,22 +43,18 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
|
||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectmodelId:{}", terms,
|
||||
regOffsetToLength, offsetList, detectmodelId);
|
||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelId:{}", terms,
|
||||
regOffsetToLength, offsetList, detectModelId);
|
||||
|
||||
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectmodelId);
|
||||
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectModelId);
|
||||
Map<MatchText, List<MapResult>> result = new HashMap<>();
|
||||
|
||||
MatchText matchText = MatchText.builder()
|
||||
.regText(text)
|
||||
.detectSegment(text)
|
||||
.build();
|
||||
result.put(matchText, detects);
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
|
||||
Long detectmodelId) {
|
||||
Long detectModelId) {
|
||||
List<MapResult> results = Lists.newArrayList();
|
||||
|
||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
||||
@@ -69,7 +65,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
int offset = mapperHelper.getStepOffset(offsetList, index);
|
||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
||||
if (i <= text.length()) {
|
||||
List<MapResult> mapResults = detectByStep(text, detectmodelId, index, i, offset);
|
||||
List<MapResult> mapResults = detectByStep(text, detectModelId, index, i, offset);
|
||||
selectMapResultInOneRound(mapResultRowSet, mapResults);
|
||||
}
|
||||
}
|
||||
@@ -106,15 +102,15 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
|
||||
private List<MapResult> detectByStep(String text, Long detectmodelId, Integer index, Integer i, int offset) {
|
||||
private List<MapResult> detectByStep(String text, Long detectModelId, Integer index, Integer i, int offset) {
|
||||
String detectSegment = text.substring(index, i);
|
||||
Integer oneDetectionSize = mapperHelper.getOneDetectionSize();
|
||||
|
||||
// step1. pre search
|
||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
|
||||
mapperHelper.getOneDetectionMaxSize())
|
||||
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize();
|
||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionSize)
|
||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
mapResults.addAll(suffixMapResults);
|
||||
@@ -126,11 +122,11 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step4. filter by classId
|
||||
if (Objects.nonNull(detectmodelId) && detectmodelId > 0) {
|
||||
log.debug("detectmodelId:{}, before parseResults:{}", mapResults);
|
||||
if (Objects.nonNull(detectModelId) && detectModelId > 0) {
|
||||
log.debug("detectModelId:{}, before parseResults:{}", mapResults);
|
||||
mapResults = mapResults.stream().map(entry -> {
|
||||
List<String> natures = entry.getNatures().stream().filter(
|
||||
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectmodelId) || (nature.startsWith(
|
||||
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectModelId) || (nature.startsWith(
|
||||
DictWordType.NATURE_SPILT))
|
||||
).collect(Collectors.toList());
|
||||
entry.setNatures(natures);
|
||||
@@ -145,8 +141,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
log.debug("metricDimensionThreshold:{},dimensionValueThreshold:{},after isSimilarity parseResults:{}",
|
||||
mapResults);
|
||||
log.debug("after isSimilarity parseResults:{}", mapResults);
|
||||
|
||||
mapResults = mapResults.stream().map(parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
@@ -165,7 +160,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||
return dimensionMetrics;
|
||||
} else {
|
||||
return mapResults.stream().limit(oneDetectionSize).collect(Collectors.toList());
|
||||
return mapResults.stream().limit(mapperHelper.getOneDetectionSize()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
@@ -21,6 +21,9 @@ public class SatisfactionChecker {
|
||||
// check all the parse info in candidate
|
||||
public static boolean check(QueryContext queryContext) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
if (query.getQueryMode().equals(DSLQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2,14 +2,7 @@ package com.tencent.supersonic.chat.parser.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
@@ -17,17 +10,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -53,48 +43,32 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return;
|
||||
}
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
if (plugin == null || DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, List<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
List<Long> modelList = pair.getRight();
|
||||
Set<Long> modelList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
return;
|
||||
}
|
||||
modelList = distinctModelList(plugin, queryContext.getMapInfo(), modelList);
|
||||
for (Long modelId : modelList) {
|
||||
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext,
|
||||
queryContext.getMapInfo().getMatchedElements(modelId));
|
||||
if (plugin.isContainsAllModel()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public List<Long> distinctModelList(Plugin plugin, SchemaMapInfo schemaMapInfo, List<Long> modelList) {
|
||||
if (!plugin.isContainsAllModel()) {
|
||||
return modelList;
|
||||
}
|
||||
boolean noElementMatch = true;
|
||||
for (Long model : modelList) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(model);
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
noElementMatch = false;
|
||||
}
|
||||
}
|
||||
if (noElementMatch) {
|
||||
return modelList.subList(0, 1);
|
||||
}
|
||||
return modelList;
|
||||
}
|
||||
|
||||
private void buildQuery(Plugin plugin, double distance, Long modelId,
|
||||
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
|
||||
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
|
||||
@@ -126,6 +100,8 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
pluginParseResult.setRequest(queryReq);
|
||||
pluginParseResult.setDistance(distance);
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
semanticParseInfo.setProperties(properties);
|
||||
semanticParseInfo.setScore(distance);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
@@ -176,4 +152,8 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -14,23 +14,16 @@ import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -62,6 +55,10 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
return;
|
||||
}
|
||||
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx);
|
||||
if (CollectionUtils.isEmpty(functionDOList)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return;
|
||||
}
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryCtx.getRequest().getQueryText())
|
||||
.pluginConfigs(functionDOList).build();
|
||||
@@ -85,7 +82,7 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
|
||||
ModelResolver ModelResolver = ComponentFactory.getModelResolver();
|
||||
log.info("plugin ModelList:{}", plugin.getModelList());
|
||||
Pair<Boolean, List<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
|
||||
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
|
||||
log.info("FunctionBasedParser modelId:{}", modelId);
|
||||
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
|
||||
@@ -102,6 +99,8 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
functionCallParseResult.setRequest(queryCtx.getRequest());
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, functionCallParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
@@ -112,17 +111,6 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
|
||||
private Set<Long> getMatchModels(QueryContext queryCtx) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
Long modelId = queryCtx.getRequest().getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
result.add(modelId);
|
||||
return result;
|
||||
}
|
||||
return queryCtx.getMapInfo().getMatchedModels();
|
||||
}
|
||||
|
||||
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
|
||||
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) {
|
||||
return true;
|
||||
@@ -140,7 +128,7 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
|
||||
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = PluginManager.getPlugins();
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
return false;
|
||||
@@ -153,12 +141,12 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
if (StringUtils.isBlank(pluginParseConfig.getName())) {
|
||||
return false;
|
||||
}
|
||||
Pair<Boolean, List<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
|
||||
Pair<Boolean, Set<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
|
||||
if (!pluginResolverResult.getLeft()) {
|
||||
return false;
|
||||
} else {
|
||||
List<Long> resolveModel = pluginResolverResult.getRight();
|
||||
Set<Long> resolveModel = pluginResolverResult.getRight();
|
||||
if (modelId != null && modelId > 0) {
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return true;
|
||||
@@ -172,20 +160,6 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
return functionDOList;
|
||||
}
|
||||
|
||||
private List<String> getFunctionNames(Set<Long> matchedModels) {
|
||||
List<Plugin> plugins = PluginManager.getPlugins();
|
||||
Set<String> functionNames = plugins.stream()
|
||||
.filter(entry -> {
|
||||
if (!CollectionUtils.isEmpty(entry.getModelList()) && !CollectionUtils.isEmpty(matchedModels)) {
|
||||
return entry.getModelList().stream().anyMatch(matchedModels::contains);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
).map(Plugin::getName).collect(Collectors.toSet());
|
||||
functionNames.add(DSLQuery.QUERY_MODE);
|
||||
return new ArrayList<>(functionNames);
|
||||
}
|
||||
|
||||
public FunctionResp requestFunction(String url, FunctionReq functionReq) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
long startTime = System.currentTimeMillis();
|
||||
@@ -205,4 +179,8 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
package com.tencent.supersonic.chat.parser.function;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@@ -55,7 +46,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
* @return false will use context Model, true will use other Model , maybe include context Model
|
||||
*/
|
||||
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap,
|
||||
ChatContext chatCtx, QueryReq searchCtx, Long modelId, List<Long> restrictiveModels) {
|
||||
ChatContext chatCtx, QueryReq searchCtx, Long modelId, Set<Long> restrictiveModels) {
|
||||
if (!Objects.nonNull(modelId) || modelId <= 0) {
|
||||
return true;
|
||||
}
|
||||
@@ -80,8 +71,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
}
|
||||
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord()
|
||||
.equalsIgnoreCase(searchCtx.getQueryText())) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord().equalsIgnoreCase(searchCtx.getQueryText())) {
|
||||
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
|
||||
semanticParseInfo.getDateInfo());
|
||||
return false;
|
||||
@@ -131,7 +121,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
}
|
||||
|
||||
|
||||
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels) {
|
||||
public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
Long modelId = queryContext.getRequest().getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isNotEmpty(restrictiveModels) && restrictiveModels.contains(modelId)) {
|
||||
@@ -151,17 +141,16 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
for (Long matchedModel : matchedModels) {
|
||||
ModelQueryModes.put(matchedModel, null);
|
||||
}
|
||||
if (ModelQueryModes.size() == 1) {
|
||||
if(ModelQueryModes.size()==1){
|
||||
return ModelQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return resolve(ModelQueryModes, queryContext, chatCtx,
|
||||
queryContext.getMapInfo(), restrictiveModels);
|
||||
queryContext.getMapInfo(),restrictiveModels);
|
||||
}
|
||||
|
||||
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext,
|
||||
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
|
||||
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,
|
||||
restrictiveModels);
|
||||
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
||||
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveModels);
|
||||
if (selectModel > 0) {
|
||||
log.info("selectModel {} ", selectModel);
|
||||
return selectModel;
|
||||
@@ -172,7 +161,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext,
|
||||
ChatContext chatCtx,
|
||||
SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
|
||||
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
||||
// if QueryContext has modelId and in ModelQueryModes
|
||||
if (ModelQueryModes.containsKey(queryContext.getModelId())) {
|
||||
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
|
||||
@@ -181,7 +170,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
// if ChatContext has modelId and in ModelQueryModes
|
||||
if (chatCtx.getParseInfo().getModelId() > 0) {
|
||||
Long modelId = chatCtx.getParseInfo().getModelId();
|
||||
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
|
||||
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId,restrictiveModels)) {
|
||||
log.info("selectModel from ChatContext [{}]", modelId);
|
||||
return modelId;
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@ package com.tencent.supersonic.chat.parser.function;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public interface ModelResolver {
|
||||
|
||||
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels);
|
||||
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMResp;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DSLParseResult extends PluginParseResult {
|
||||
|
||||
private LLMResp llmResp;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
|
||||
public class DSLDateHelper {
|
||||
|
||||
public static String getCurrentDate(Long modelId) {
|
||||
return DateUtils.getBeforeDate(4);
|
||||
// ChatConfigFilter filter = new ChatConfigFilter();
|
||||
// filter.setModelId(modelId);
|
||||
//
|
||||
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
||||
// if (CollectionUtils.isEmpty(configResps)) {
|
||||
// return
|
||||
// }
|
||||
// ChatConfigResp chatConfigResp = configResps.get(0);
|
||||
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMResp;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DSLParseResult {
|
||||
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private DslTool dslTool;
|
||||
}
|
||||
@@ -1,5 +1,10 @@
|
||||
package com.tencent.supersonic.chat.parser.llm;
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
@@ -11,26 +16,26 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.config.LLMConfig;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.parser.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLBuilder;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.dsl.optimizer.BaseDSLOptimizer;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -56,30 +61,30 @@ public class LLMDSLParser implements SemanticParser {
|
||||
queryCtx.getRequest().getQueryText());
|
||||
return;
|
||||
}
|
||||
List<Plugin> dslPlugins = PluginManager.getPlugins().stream()
|
||||
.filter(plugin -> DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isEmpty(dslPlugins)) {
|
||||
return;
|
||||
}
|
||||
Plugin plugin = dslPlugins.get(0);
|
||||
List<Long> dslModels = plugin.getModelList();
|
||||
|
||||
List<DslTool> dslTools = getDslTools(queryCtx.getRequest().getAgentId());
|
||||
Set<Long> distinctModelIds = dslTools.stream().map(DslTool::getModelIds)
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
try {
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, dslModels);
|
||||
log.info("resolve modelId:{},dslModels:{}", modelId, dslModels);
|
||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
|
||||
|
||||
if (Objects.isNull(modelId)) {
|
||||
if (Objects.isNull(modelId) || modelId <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool ->
|
||||
tool.getModelIds().contains(modelId)).findFirst();
|
||||
if (!dslToolOptional.isPresent()) {
|
||||
log.info("no dsl tool in this agent, skip dsl parser");
|
||||
return;
|
||||
}
|
||||
DslTool dslTool = dslToolOptional.get();
|
||||
LLMResp llmResp = requestLLM(queryCtx, modelId);
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE);
|
||||
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
@@ -89,10 +94,12 @@ public class LLMDSLParser implements SemanticParser {
|
||||
DSLParseResult dslParseResult = new DSLParseResult();
|
||||
dslParseResult.setRequest(queryCtx.getRequest());
|
||||
dslParseResult.setLlmResp(llmResp);
|
||||
dslParseResult.setPlugin(plugin);
|
||||
dslParseResult.setDslTool(dslToolOptional.get());
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, dslParseResult);
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", dslTool.getName());
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
@@ -126,13 +133,13 @@ public class LLMDSLParser implements SemanticParser {
|
||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
|
||||
fieldNameList.add(DSLBuilder.DATA_Field);
|
||||
fieldNameList.add(BaseDSLOptimizer.DATE_FIELD);
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmReq.setSchema(llmSchema);
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
||||
llmReq.setLinking(linking);
|
||||
String currentDate = getCurrentDate(modelId);
|
||||
String currentDate = DSLDateHelper.getCurrentDate(modelId);
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
|
||||
@@ -156,21 +163,6 @@ public class LLMDSLParser implements SemanticParser {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
private String getCurrentDate(Long modelId) {
|
||||
return DateUtils.getBeforeDate(4);
|
||||
// ChatConfigFilter filter = new ChatConfigFilter();
|
||||
// filter.setModelId(modelId);
|
||||
//
|
||||
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
||||
// if (CollectionUtils.isEmpty(configResps)) {
|
||||
// return
|
||||
// }
|
||||
// ChatConfigResp chatConfigResp = configResps.get(0);
|
||||
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
|
||||
|
||||
}
|
||||
|
||||
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||
|
||||
@@ -228,4 +220,18 @@ public class LLMDSLParser implements SemanticParser {
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private List<DslTool> getDslTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.DSL);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.interpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.metricInterpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class MetricInterpretParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (SatisfactionChecker.check(queryContext)) {
|
||||
log.info("skip MetricInterpretParser");
|
||||
return;
|
||||
}
|
||||
Map<Long, MetricInterpretTool> metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId());
|
||||
log.info("metric interpret tool : {}", metricInterpretToolMap);
|
||||
if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
|
||||
return;
|
||||
}
|
||||
Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
for (Long modelId : elementMatches.keySet()) {
|
||||
MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId);
|
||||
if (metricInterpretTool == null) {
|
||||
continue;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(elementMatches.get(modelId))) {
|
||||
continue;
|
||||
}
|
||||
List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions();
|
||||
if (!CollectionUtils.isEmpty(metricOptions)) {
|
||||
List<Long> metricIds = metricOptions.stream().map(MetricOption::getMetricId).collect(Collectors.toList());
|
||||
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), metricInterpretTool.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void buildQuery(Long modelId, QueryContext queryContext,
|
||||
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
||||
metrics, schemaElementMatches, toolName);
|
||||
semanticParseInfo.setQueryMode(metricInterpretQuery.getQueryMode());
|
||||
semanticParseInfo.getProperties().put("queryText", queryContext.getRequest().getQueryText());
|
||||
metricInterpretQuery.setParseInfo(semanticParseInfo);
|
||||
queryContext.getCandidateQueries().add(metricInterpretQuery);
|
||||
}
|
||||
|
||||
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
|
||||
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
ModelSchema modelSchema = semanticLayer.getModelSchema(modelId, true);
|
||||
Set<SchemaElement> metrics = modelSchema.getMetrics();
|
||||
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private Map<Long, MetricInterpretTool> getMetricInterpretTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<String> tools= agent.getTools(AgentToolType.INTERPRET);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<MetricInterpretTool> metricInterpretTools = tools.stream().map(tool ->
|
||||
JSONObject.parseObject(tool, MetricInterpretTool.class))
|
||||
.filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions()))
|
||||
.collect(Collectors.toList());
|
||||
Map<Long, MetricInterpretTool> metricInterpretToolMap = new HashMap<>();
|
||||
for (MetricInterpretTool metricInterpretTool : metricInterpretTools) {
|
||||
metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(),
|
||||
metricInterpretTool);
|
||||
}
|
||||
return metricInterpretToolMap;
|
||||
}
|
||||
|
||||
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
|
||||
List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
SchemaElement Model = new SchemaElement();
|
||||
Model.setModel(modelId);
|
||||
Model.setId(modelId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setMetrics(metrics);
|
||||
SchemaElement dimension = new SchemaElement();
|
||||
dimension.setBizName(TimeDimensionEnum.DAY.getName());
|
||||
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(Model);
|
||||
semanticParseInfo.setScore(queryReq.getQueryText().length());
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||
dateConf.setUnit(15);
|
||||
semanticParseInfo.setDateInfo(dateConf);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", toolName);
|
||||
semanticParseInfo.setProperties(properties);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.forEach(schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.interpret;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class MetricOption {
|
||||
|
||||
private Long metricId;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm;
|
||||
package com.tencent.supersonic.chat.parser.llm.time;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.utils.ChatGptHelper;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ChatGptHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -17,20 +17,20 @@ public class LLMTimeEnhancementParse implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
|
||||
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
|
||||
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
|
||||
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
|
||||
try {
|
||||
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
|
||||
if (!queryContext.getCandidateQueries().isEmpty()) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
DateConf dateInfo = query.getParseInfo().getDateInfo();
|
||||
JSONObject jsonObject = JSON.parseObject(inferredTime);
|
||||
if (jsonObject.containsKey("date")) {
|
||||
if (jsonObject.containsKey("date")){
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(jsonObject.getString("date"));
|
||||
dateInfo.setEndDate(jsonObject.getString("date"));
|
||||
query.getParseInfo().setDateInfo(dateInfo);
|
||||
} else if (jsonObject.containsKey("start")) {
|
||||
}else if (jsonObject.containsKey("start")){
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(jsonObject.getString("start"));
|
||||
dateInfo.setEndDate(jsonObject.getString("end"));
|
||||
@@ -38,12 +38,11 @@ public class LLMTimeEnhancementParse implements SemanticParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception exception) {
|
||||
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
|
||||
(Object) exception.getStackTrace());
|
||||
}catch (Exception exception){
|
||||
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
|
||||
}
|
||||
|
||||
log.info("after queryContext:{},chatContext:{}", queryContext, chatContext);
|
||||
log.info("{} after queryContext:{},chatContext:{}",LLMTimeEnhancementParse.class.getSimpleName(),queryContext,chatContext);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class AgentCheckParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
List<SemanticQuery> queries = queryContext.getCandidateQueries();
|
||||
agentCanSupport(queryContext.getRequest().getAgentId(), queries);
|
||||
}
|
||||
|
||||
private void agentCanSupport(Integer agentId, List<SemanticQuery> queries) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return;
|
||||
}
|
||||
List<String> queryModes = getRuleTools(agentId).stream().map(RuleQueryTool::getQueryModes)
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
queries.clear();
|
||||
return;
|
||||
}
|
||||
log.info("queries resolved:{} {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
queries.removeIf(query ->
|
||||
!queryModes.contains(query.getQueryMode()));
|
||||
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static List<RuleQueryTool> getRuleTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.RULE);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,12 +1,9 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
@@ -15,12 +12,10 @@ public class QueryModeParser implements SemanticParser {
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
|
||||
// iterate all schemaElementMatches to resolve semantic query
|
||||
for (Long modelId : mapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(modelId, queryContext, chatContext);
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
public class AgentDO {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String name;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
*/
|
||||
private Integer status;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String examples;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String config;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String createdBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date createdAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String updatedBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date updatedAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Integer enableSearch;
|
||||
|
||||
/**
|
||||
*
|
||||
* @return id
|
||||
*/
|
||||
public Integer getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param id
|
||||
*/
|
||||
public void setId(Integer id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return name
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name
|
||||
*/
|
||||
public void setName(String name) {
|
||||
this.name = name == null ? null : name.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return description
|
||||
*/
|
||||
public String getDescription() {
|
||||
return description;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param description
|
||||
*/
|
||||
public void setDescription(String description) {
|
||||
this.description = description == null ? null : description.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
* @return status 0 offline, 1 online
|
||||
*/
|
||||
public Integer getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
* @param status 0 offline, 1 online
|
||||
*/
|
||||
public void setStatus(Integer status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return examples
|
||||
*/
|
||||
public String getExamples() {
|
||||
return examples;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param examples
|
||||
*/
|
||||
public void setExamples(String examples) {
|
||||
this.examples = examples == null ? null : examples.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return config
|
||||
*/
|
||||
public String getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param config
|
||||
*/
|
||||
public void setConfig(String config) {
|
||||
this.config = config == null ? null : config.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return created_by
|
||||
*/
|
||||
public String getCreatedBy() {
|
||||
return createdBy;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param createdBy
|
||||
*/
|
||||
public void setCreatedBy(String createdBy) {
|
||||
this.createdBy = createdBy == null ? null : createdBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return created_at
|
||||
*/
|
||||
public Date getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param createdAt
|
||||
*/
|
||||
public void setCreatedAt(Date createdAt) {
|
||||
this.createdAt = createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return updated_by
|
||||
*/
|
||||
public String getUpdatedBy() {
|
||||
return updatedBy;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param updatedBy
|
||||
*/
|
||||
public void setUpdatedBy(String updatedBy) {
|
||||
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return updated_at
|
||||
*/
|
||||
public Date getUpdatedAt() {
|
||||
return updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param updatedAt
|
||||
*/
|
||||
public void setUpdatedAt(Date updatedAt) {
|
||||
this.updatedAt = updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return enable_search
|
||||
*/
|
||||
public Integer getEnableSearch() {
|
||||
return enableSearch;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param enableSearch
|
||||
*/
|
||||
public void setEnableSearch(Integer enableSearch) {
|
||||
this.enableSearch = enableSearch;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.chat.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface AgentDOMapper {
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
long countByExample(AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int deleteByPrimaryKey(Integer id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insert(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insertSelective(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<AgentDO> selectByExample(AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
AgentDO selectByPrimaryKey(Integer id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeySelective(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKey(AgentDO record);
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentRepository {
|
||||
|
||||
List<AgentDO> getAgents();
|
||||
|
||||
void createAgent(AgentDO agentDO);
|
||||
|
||||
void updateAgent(AgentDO agentDO);
|
||||
|
||||
AgentDO getAgent(Integer id);
|
||||
|
||||
void deleteAgent(Integer id);
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.chat.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import java.util.List;
|
||||
|
||||
@Repository
|
||||
public class AgentRepositoryImpl implements AgentRepository {
|
||||
|
||||
private AgentDOMapper agentDOMapper;
|
||||
|
||||
public AgentRepositoryImpl(AgentDOMapper agentDOMapper) {
|
||||
this.agentDOMapper = agentDOMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AgentDO> getAgents() {
|
||||
return agentDOMapper.selectByExample(new AgentDOExample());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createAgent(AgentDO agentDO) {
|
||||
agentDOMapper.insert(agentDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAgent(AgentDO agentDO) {
|
||||
agentDOMapper.updateByPrimaryKey(agentDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentDO getAgent(Integer id) {
|
||||
return agentDOMapper.selectByPrimaryKey(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAgent(Integer id) {
|
||||
agentDOMapper.deleteByPrimaryKey(id);
|
||||
}
|
||||
}
|
||||
@@ -7,12 +7,14 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@Slf4j
|
||||
public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
|
||||
@Autowired(required = false)
|
||||
@@ -50,8 +52,8 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
chatContext.setUser(contextDO.getUser());
|
||||
chatContext.setQueryText(contextDO.getQueryText());
|
||||
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
|
||||
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
|
||||
SemanticParseInfo.class);
|
||||
log.info("--->: {}",contextDO.getSemanticParse());
|
||||
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class);
|
||||
chatContext.setParseInfo(semanticParseInfo);
|
||||
}
|
||||
return chatContext;
|
||||
|
||||
@@ -3,11 +3,11 @@ package com.tencent.supersonic.chat.plugin;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.PluginTool;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
|
||||
@@ -16,30 +16,22 @@ import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
@@ -59,12 +51,40 @@ public class PluginManager {
|
||||
this.restTemplate = restTemplate;
|
||||
}
|
||||
|
||||
public static List<Plugin> getPlugins() {
|
||||
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
List<Plugin> pluginList = pluginService.getPluginList().stream().filter(plugin ->
|
||||
CollectionUtils.isNotEmpty(plugin.getModelList())).collect(Collectors.toList());
|
||||
pluginList.addAll(internalPluginMap.values());
|
||||
return new ArrayList<>(pluginList);
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
if (agentId == null) {
|
||||
return plugins;
|
||||
}
|
||||
Agent agent = ContextUtils.getBean(AgentService.class).getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return plugins;
|
||||
}
|
||||
List<Long> pluginIds = getPluginTools(agentId).stream().map(PluginTool::getPlugins)
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(pluginIds)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId()))
|
||||
.collect(Collectors.toList());
|
||||
log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
|
||||
plugins.stream().map(Plugin::getName).collect(Collectors.toList()));
|
||||
return plugins;
|
||||
}
|
||||
|
||||
private static List<PluginTool> getPluginTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.PLUGIN);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@EventListener
|
||||
@@ -201,17 +221,17 @@ public class PluginManager {
|
||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||
}
|
||||
|
||||
public static Pair<Boolean, List<Long>> resolve(Plugin plugin, QueryContext queryContext) {
|
||||
public static Pair<Boolean, Set<Long>> resolve(Plugin plugin, QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
|
||||
return Pair.of(false, Lists.newArrayList());
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
List<ParamOption> paramOptions = getSemanticOption(plugin);
|
||||
if (CollectionUtils.isEmpty(paramOptions)) {
|
||||
return Pair.of(true, new ArrayList<>(pluginMatchedModel));
|
||||
return Pair.of(true, Sets.newHashSet());
|
||||
}
|
||||
List<Long> matchedModel = Lists.newArrayList();
|
||||
Set<Long> matchedModel = Sets.newHashSet();
|
||||
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream().
|
||||
collect(Collectors.groupingBy(ParamOption::getModelId));
|
||||
for (Long modelId : paramOptionMap.keySet()) {
|
||||
@@ -237,7 +257,7 @@ public class PluginManager {
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(matchedModel)) {
|
||||
return Pair.of(false, Lists.newArrayList());
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
return Pair.of(true, matchedModel);
|
||||
}
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
package com.tencent.supersonic.chat.query.ContentInterpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class ContentInterpretQuery extends PluginSemanticQuery {
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return "CONTENT_INTERPRET";
|
||||
}
|
||||
|
||||
public ContentInterpretQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) throws SqlParseException {
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = queryMetric(user);
|
||||
String text = generateDataText(queryResultWithSchemaResp);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT))
|
||||
, PluginParseResult.class);
|
||||
String answer = fetchInterpret(pluginParseResult.getRequest().getQueryText(), text);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("answer", answer);
|
||||
List<Map<String, Object>> resultList = Lists.newArrayList();
|
||||
resultList.add(result);
|
||||
queryResultWithSchemaResp.setResultList(resultList);
|
||||
queryResultWithSchemaResp.setColumns(queryColumns);
|
||||
queryResult.setResponse(queryResultWithSchemaResp);
|
||||
queryResult.setQueryMode(getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
private QueryResultWithSchemaResp queryMetric(User user) {
|
||||
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setModelId(parseInfo.getModelId());
|
||||
queryStructReq.setGroups(Lists.newArrayList(TimeDimensionEnum.DAY.getName()));
|
||||
ModelSchema modelSchema = semanticLayer.getModelSchema(parseInfo.getModelId(), true);
|
||||
queryStructReq.setAggregators(buildAggregator(modelSchema));
|
||||
List<Filter> filterList = Lists.newArrayList();
|
||||
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
|
||||
Filter filter = new Filter();
|
||||
BeanUtils.copyProperties(queryFilter, filter);
|
||||
filterList.add(filter);
|
||||
}
|
||||
queryStructReq.setDimensionFilters(filterList);
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||
dateConf.setUnit(7);
|
||||
queryStructReq.setDateInfo(dateConf);
|
||||
return semanticLayer.queryByStruct(queryStructReq, user);
|
||||
}
|
||||
|
||||
private List<Aggregator> buildAggregator(ModelSchema modelSchema) {
|
||||
List<Aggregator> aggregators = Lists.newArrayList();
|
||||
Set<SchemaElement> metrics = modelSchema.getMetrics();
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return aggregators;
|
||||
}
|
||||
for (SchemaElement schemaElement : metrics) {
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.setColumn(schemaElement.getBizName());
|
||||
aggregator.setFunc(AggOperatorEnum.SUM);
|
||||
aggregator.setNameCh(schemaElement.getName());
|
||||
aggregators.add(aggregator);
|
||||
}
|
||||
return aggregators;
|
||||
}
|
||||
|
||||
|
||||
public String generateDataText(QueryResultWithSchemaResp queryResultWithSchemaResp) {
|
||||
Map<String, String> map = queryResultWithSchemaResp.getColumns().stream()
|
||||
.collect(Collectors.toMap(QueryColumn::getNameEn, QueryColumn::getName));
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
for (Map<String, Object> valueMap : queryResultWithSchemaResp.getResultList()) {
|
||||
for (String key : valueMap.keySet()) {
|
||||
String name = "";
|
||||
if (TimeDimensionEnum.getNameList().contains(key)) {
|
||||
name = "日期";
|
||||
} else {
|
||||
name = map.get(key);
|
||||
}
|
||||
String value = String.valueOf(valueMap.get(key));
|
||||
stringBuilder.append(name).append(":").append(value).append(" ");
|
||||
}
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
|
||||
lLmAnswerReq.setQueryText(queryText);
|
||||
lLmAnswerReq.setPluginOutput(dataText);
|
||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
||||
JSONObject.toJSONString(lLmAnswerReq));
|
||||
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
|
||||
if (lLmAnswerResp != null) {
|
||||
return lLmAnswerResp.getAssistant_message();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.query;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
@@ -18,7 +19,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
||||
private static final double CANDIDATE_THRESHOLD = 0.2;
|
||||
|
||||
@Override
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries) {
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@@ -8,5 +10,5 @@ import java.util.List;
|
||||
**/
|
||||
public interface QuerySelector {
|
||||
|
||||
List<SemanticQuery> select(List<SemanticQuery> candidateQueries);
|
||||
List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq);
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package com.tencent.supersonic.chat.query.dsl;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class DSLBuilder {
|
||||
|
||||
public static final String DATA_Field = "数据日期";
|
||||
public static final String TABLE_PREFIX = "t_";
|
||||
|
||||
public String build(SemanticParseInfo parseInfo, QueryFilters queryFilters, LLMResp llmResp, Long modelId)
|
||||
throws Exception {
|
||||
|
||||
String sqlOutput = llmResp.getSqlOutput();
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||
|
||||
Map<String, String> fieldToBizName = getMapInfo(modelId, dbAllFields);
|
||||
fieldToBizName.put(DATA_Field, TimeDimensionEnum.DAY.getName());
|
||||
|
||||
sqlOutput = CCJSqlParserUtils.replaceFields(sqlOutput, fieldToBizName);
|
||||
|
||||
sqlOutput = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
|
||||
|
||||
String queryFilter = getQueryFilter(queryFilters);
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to sql :{}", queryFilter);
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
CCJSqlParserUtils.addWhere(sqlOutput, expression);
|
||||
}
|
||||
|
||||
log.info("build sqlOutput:{}", sqlOutput);
|
||||
return sqlOutput;
|
||||
}
|
||||
|
||||
protected Map<String, String> getMapInfo(Long modelId, List<SchemaElement> metrics) {
|
||||
return metrics.stream().filter(entry -> entry.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return "";
|
||||
}
|
||||
List<QueryFilter> filters = queryFilters.getFilters();
|
||||
|
||||
return filters.stream()
|
||||
.map(filter -> {
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||
return bizNameWrap + operatorWrap + valueWrap;
|
||||
})
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
}
|
||||
}
|
||||
@@ -2,14 +2,14 @@ package com.tencent.supersonic.chat.query.dsl;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.parser.llm.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
@@ -32,7 +32,6 @@ import org.springframework.stereotype.Component;
|
||||
public class DSLQuery extends PluginSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DSL";
|
||||
private DSLBuilder dslBuilder = new DSLBuilder();
|
||||
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
|
||||
public DSLQuery() {
|
||||
@@ -51,12 +50,26 @@ public class DSLQuery extends PluginSemanticQuery {
|
||||
LLMResp llmResp = dslParseResult.getLlmResp();
|
||||
QueryReq queryReq = dslParseResult.getRequest();
|
||||
|
||||
Long modelId = parseInfo.getModelId();
|
||||
String querySql = convertToSql(queryReq.getQueryFilters(), llmResp, parseInfo, modelId);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.queryFilters(queryReq.getQueryFilters())
|
||||
.sql(llmResp.getSqlOutput())
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
List<DSLOptimizer> DSLCorrections = ComponentFactory.getSqlCorrections();
|
||||
|
||||
DSLCorrections.forEach(DSLCorrection -> {
|
||||
try {
|
||||
DSLCorrection.rewriter(correctionInfo);
|
||||
log.info("sqlCorrection:{} sql:{}", DSLCorrection.getClass().getSimpleName(), correctionInfo.getSql());
|
||||
} catch (Exception e) {
|
||||
log.error("sqlCorrection:{} execute error,correctionInfo:{}", DSLCorrection, correctionInfo, e);
|
||||
}
|
||||
});
|
||||
String querySql = correctionInfo.getSql();
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, modelId);
|
||||
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, parseInfo.getModelId());
|
||||
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
|
||||
|
||||
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
@@ -80,17 +93,4 @@ public class DSLQuery extends PluginSemanticQuery {
|
||||
parseInfo.setProperties(null);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo,
|
||||
Long modelId) {
|
||||
try {
|
||||
return dslBuilder.build(parseInfo, queryFilters, llmResp, modelId);
|
||||
} catch (Exception e) {
|
||||
log.error("convertToSql error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseDSLOptimizer implements DSLOptimizer {
|
||||
public static final String DATE_FIELD = "数据日期";
|
||||
protected Map<String, String> getFieldToBizName(Long modelId) {
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> entry.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
|
||||
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class DateFieldCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
|
||||
String sql = correctionInfo.getSql();
|
||||
List<String> whereFields = CCJSqlParserUtils.getWhereFields(sql);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(BaseDSLOptimizer.DATE_FIELD)) {
|
||||
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
|
||||
sql = CCJSqlParserUtils.addWhere(sql, BaseDSLOptimizer.DATE_FIELD, currentDate);
|
||||
}
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FieldCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
String replaceFields = CCJSqlParserUtils.replaceFields(correctionInfo.getSql(),
|
||||
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
|
||||
correctionInfo.setSql(replaceFields);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
String replaceFunction = CCJSqlParserUtils.replaceFunction(correctionInfo.getSql());
|
||||
correctionInfo.setSql(replaceFunction);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterAppend extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException {
|
||||
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
|
||||
String sql = correctionInfo.getSql();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to sql :{}", queryFilter);
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
sql = CCJSqlParserUtils.addWhere(sql, expression);
|
||||
}
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return null;
|
||||
}
|
||||
return queryFilters.getFilters().stream()
|
||||
.map(filter -> {
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||
return bizNameWrap + operatorWrap + valueWrap;
|
||||
})
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class SelectFieldAppendCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
String sql = correctionInfo.getSql();
|
||||
if (CCJSqlParserUtils.hasAggregateFunction(sql)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
Set<String> selectFields = new HashSet<>(CCJSqlParserUtils.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(CCJSqlParserUtils.getWhereFields(sql));
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
String replaceFields = CCJSqlParserUtils.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
correctionInfo.setSql(replaceFields);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class TableNameCorrector extends BaseDSLOptimizer {
|
||||
|
||||
public static final String TABLE_PREFIX = "t_";
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
Long modelId = correctionInfo.getParseInfo().getModelId();
|
||||
String sqlOutput = correctionInfo.getSql();
|
||||
String replaceTable = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
|
||||
correctionInfo.setSql(replaceTable);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.ContentInterpret;
|
||||
package com.tencent.supersonic.chat.query.metricInterpret;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.ContentInterpret;
|
||||
package com.tencent.supersonic.chat.query.metricInterpret;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,143 @@
|
||||
package com.tencent.supersonic.chat.query.metricInterpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricInterpretQuery extends PluginSemanticQuery {
|
||||
|
||||
|
||||
public final static String QUERY_MODE = "METRIC_INTERPRET";
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
public MetricInterpretQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) throws SqlParseException {
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
|
||||
fillAggregator(queryStructReq, parseInfo.getMetrics());
|
||||
queryStructReq.setNativeQuery(true);
|
||||
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticLayer.queryByStruct(queryStructReq, user);
|
||||
String text = generateTableText(queryResultWithSchemaResp);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
Map<String, String> replacedMap = new HashMap<>();
|
||||
String textReplaced = replaceText((String) properties.get("queryText"), parseInfo.getElementMatches(), replacedMap);
|
||||
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果","string","answer"));
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("answer", answer);
|
||||
List<Map<String, Object>> resultList = Lists.newArrayList();
|
||||
resultList.add(result);
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(queryColumns);
|
||||
queryResult.setQueryMode(getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches, Map<String, String> replacedMap) {
|
||||
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
return text;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
|
||||
String detectWord = schemaElementMatch.getDetectWord();
|
||||
if (StringUtils.isBlank(detectWord)) {
|
||||
continue;
|
||||
}
|
||||
text = text.replace(detectWord, "xxx");
|
||||
replacedMap.put("xxx", detectWord);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
private void fillAggregator(QueryStructReq queryStructReq, Set<SchemaElement> schemaElements) {
|
||||
queryStructReq.getAggregators().clear();
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.setColumn(schemaElement.getBizName());
|
||||
aggregator.setFunc(AggOperatorEnum.SUM);
|
||||
aggregator.setNameCh(schemaElement.getName());
|
||||
queryStructReq.getAggregators().add(aggregator);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private String replaceAnswer(String text, Map<String, String> replacedMap) {
|
||||
for (String key : replacedMap.keySet()) {
|
||||
text = text.replaceAll(key, replacedMap.get(key));
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
public static String generateTableText(QueryResultWithSchemaResp result) {
|
||||
StringBuilder tableBuilder = new StringBuilder();
|
||||
for (QueryColumn column : result.getColumns()) {
|
||||
tableBuilder.append(column.getName()).append("\t");
|
||||
}
|
||||
tableBuilder.append("\n");
|
||||
for (Map<String, Object> row : result.getResultList()) {
|
||||
for (QueryColumn column : result.getColumns()) {
|
||||
tableBuilder.append(row.get(column.getNameEn())).append("\t");
|
||||
}
|
||||
tableBuilder.append("\n");
|
||||
}
|
||||
return tableBuilder.toString();
|
||||
}
|
||||
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
|
||||
lLmAnswerReq.setQueryText(queryText);
|
||||
lLmAnswerReq.setPluginOutput(dataText);
|
||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
||||
JSONObject.toJSONString(lLmAnswerReq));
|
||||
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
|
||||
if (lLmAnswerResp != null) {
|
||||
return lLmAnswerResp.getAssistant_message();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -2,15 +2,14 @@ package com.tencent.supersonic.chat.query.plugin.webpage;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
@@ -18,18 +17,18 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebPageQuery extends PluginSemanticQuery {
|
||||
@@ -107,17 +106,15 @@ public class WebPageQuery extends PluginSemanticQuery {
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.sorted(Comparator.comparingDouble(SchemaElementMatch::getSimilarity))
|
||||
.filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0)
|
||||
.forEach(schemaElementMatch -> {
|
||||
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
|
||||
if (queryFilterValue != null) {
|
||||
if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) {
|
||||
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
schemaElementMatch.getWord());
|
||||
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()), schemaElementMatch.getWord());
|
||||
}
|
||||
} else {
|
||||
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
schemaElementMatch.getWord());
|
||||
elementValueMap.computeIfAbsent(String.valueOf(schemaElementMatch.getElement().getId()), k -> schemaElementMatch.getWord());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -208,8 +208,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
QueryResult queryResult = new QueryResult();
|
||||
QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(
|
||||
convertQueryStruct(), user);
|
||||
QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(convertQueryStruct(), user);
|
||||
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -17,8 +17,7 @@ public class EntityFilterQuery extends EntityListQuery {
|
||||
|
||||
public EntityFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0);
|
||||
queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0);
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Component
|
||||
public class EntityIdQuery extends EntityListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "ENTITY_ID";
|
||||
|
||||
public EntityIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package com.tencent.supersonic.chat.rest;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/agent")
|
||||
public class AgentController {
|
||||
|
||||
private AgentService agentService;
|
||||
|
||||
public AgentController(AgentService agentService) {
|
||||
this.agentService = agentService;
|
||||
}
|
||||
|
||||
@PostMapping
|
||||
public boolean createAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
agentService.createAgent(agent, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public boolean updateAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
agentService.updateAgent(agent, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
@DeleteMapping("/{id}")
|
||||
public boolean deleteAgent(@PathVariable("id") Integer id) {
|
||||
agentService.deleteAgent(id);
|
||||
return true;
|
||||
}
|
||||
|
||||
@RequestMapping("/getAgentList")
|
||||
public List<Agent> getAgentList() {
|
||||
return agentService.getAgents();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -32,7 +32,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/conf")
|
||||
@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"})
|
||||
public class ChatConfigController {
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -18,7 +18,7 @@ import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/manage")
|
||||
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
|
||||
public class ChatController {
|
||||
|
||||
private final ChatService chatService;
|
||||
|
||||
@@ -20,7 +20,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
* query controller
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/query")
|
||||
@RequestMapping({"/api/chat/query", "/openapi/chat/query"})
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -19,7 +19,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
* recommend controller
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/")
|
||||
@RequestMapping({"/api/chat/", "/openapi/chat/"})
|
||||
public class RecommendController {
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentService {
|
||||
|
||||
List<Agent> getAgents();
|
||||
|
||||
void createAgent(Agent agent, User user);
|
||||
|
||||
void updateAgent(Agent agent, User user);
|
||||
|
||||
Agent getAgent(Integer id);
|
||||
|
||||
void deleteAgent(Integer id);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class AgentServiceImpl implements AgentService {
|
||||
|
||||
private AgentRepository agentRepository;
|
||||
|
||||
public AgentServiceImpl(AgentRepository agentRepository) {
|
||||
this.agentRepository = agentRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
return getAgentDOList().stream()
|
||||
.map(this::convert).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createAgent(Agent agent, User user) {
|
||||
agentRepository.createAgent(convert(agent, user));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAgent(Agent agent, User user) {
|
||||
agentRepository.updateAgent(convert(agent, user));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Agent getAgent(Integer id) {
|
||||
if (id == null) {
|
||||
return null;
|
||||
}
|
||||
return convert(agentRepository.getAgent(id));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAgent(Integer id) {
|
||||
agentRepository.deleteAgent(id);
|
||||
}
|
||||
|
||||
private List<AgentDO> getAgentDOList() {
|
||||
return agentRepository.getAgents();
|
||||
}
|
||||
|
||||
private Agent convert(AgentDO agentDO){
|
||||
if (agentDO == null ) {
|
||||
return null;
|
||||
}
|
||||
Agent agent = new Agent();
|
||||
BeanUtils.copyProperties(agentDO,agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class));
|
||||
return agent;
|
||||
}
|
||||
|
||||
private AgentDO convert(Agent agent, User user){
|
||||
AgentDO agentDO = new AgentDO();
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setConfig(agent.getAgentConfig());
|
||||
agentDO.setExamples(JSONObject.toJSONString(agent.getExamples()));
|
||||
agentDO.setCreatedAt(new Date());
|
||||
agentDO.setCreatedBy(user.getName());
|
||||
agentDO.setUpdatedAt(new Date());
|
||||
agentDO.setUpdatedBy(user.getName());
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
return agentDO;
|
||||
}
|
||||
}
|
||||
@@ -2,26 +2,25 @@ package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.component.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -63,12 +62,14 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (queryCtx.getCandidateQueries().size() > 0) {
|
||||
log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
|
||||
Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries());
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
|
||||
log.debug("pick after [{}]", selectedQueries.stream().collect(
|
||||
Collectors.toList()));
|
||||
|
||||
List<SemanticParseInfo> selectedParses = selectedQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
.map(SemanticQuery::getParseInfo)
|
||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||
.collect(Collectors.toList());
|
||||
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
|
||||
@@ -138,7 +139,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (queryCtx.getCandidateQueries().size() > 0) {
|
||||
log.info("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
|
||||
Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries());
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
|
||||
log.info("pick after [{}]", selectedQueries.stream().collect(
|
||||
Collectors.toList()));
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ public class RecommendServiceImpl implements RecommendService {
|
||||
item.setName(dimSchemaDesc.getName());
|
||||
item.setBizName(dimSchemaDesc.getBizName());
|
||||
item.setId(dimSchemaDesc.getId());
|
||||
item.setAlias(dimSchemaDesc.getAlias());
|
||||
return item;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
@@ -70,6 +71,7 @@ public class RecommendServiceImpl implements RecommendService {
|
||||
item.setName(metricSchemaDesc.getName());
|
||||
item.setBizName(metricSchemaDesc.getBizName());
|
||||
item.setId(metricSchemaDesc.getId());
|
||||
item.setAlias(metricSchemaDesc.getAlias());
|
||||
return item;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
@@ -13,6 +14,7 @@ import com.tencent.supersonic.chat.mapper.MatchText;
|
||||
import com.tencent.supersonic.chat.mapper.ModelInfoStat;
|
||||
import com.tencent.supersonic.chat.mapper.ModelWithSemanticType;
|
||||
import com.tencent.supersonic.chat.mapper.SearchMatchStrategy;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.SearchService;
|
||||
import com.tencent.supersonic.chat.utils.NatureHelper;
|
||||
@@ -53,22 +55,34 @@ public class SearchServiceImpl implements SearchService {
|
||||
private ChatService chatService;
|
||||
@Autowired
|
||||
private SearchMatchStrategy searchMatchStrategy;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(QueryReq queryCtx) {
|
||||
|
||||
// 1. check search enable
|
||||
Integer agentId = queryCtx.getAgentId();
|
||||
if (agentId != null) {
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
}
|
||||
|
||||
String queryText = queryCtx.getQueryText();
|
||||
// 1.get meta info
|
||||
// 2.get meta info
|
||||
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
|
||||
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
|
||||
final Map<Long, String> modelToName = semanticSchemaDb.getModelIdToName();
|
||||
|
||||
// 2.detect by segment
|
||||
// 3.detect by segment
|
||||
List<Term> originals = HanlpHelper.getTerms(queryText);
|
||||
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.match(queryText, originals,
|
||||
queryCtx.getModelId());
|
||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||
|
||||
// 3.get the most matching data
|
||||
// 4.get the most matching data
|
||||
Optional<Entry<MatchText, List<MapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
|
||||
.stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
@@ -77,7 +91,7 @@ public class SearchServiceImpl implements SearchService {
|
||||
? entry1 : entry2);
|
||||
log.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult);
|
||||
|
||||
// 4.optimize the results after the query
|
||||
// 5.optimize the results after the query
|
||||
if (!mostSimilarSearchResult.isPresent()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
@@ -89,11 +103,11 @@ public class SearchServiceImpl implements SearchService {
|
||||
|
||||
List<Long> possibleModels = getPossibleModels(queryCtx, originals, modelStat, queryCtx.getModelId());
|
||||
|
||||
// 4.1 priority dimension metric
|
||||
// 5.1 priority dimension metric
|
||||
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName,
|
||||
searchTextEntry, searchResults);
|
||||
|
||||
// 4.2 process based on dimension values
|
||||
// 5.2 process based on dimension values
|
||||
MatchText matchText = searchTextEntry.getKey();
|
||||
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
|
||||
log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package com.tencent.supersonic.chat.utils;
|
||||
|
||||
|
||||
import com.plexpt.chatgpt.ChatGPT;
|
||||
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
|
||||
import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse;
|
||||
import com.plexpt.chatgpt.entity.chat.Message;
|
||||
import com.plexpt.chatgpt.util.Proxys;
|
||||
import java.net.Proxy;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
|
||||
@Component
|
||||
public class ChatGptHelper {
|
||||
|
||||
@Value("${llm.chatgpt.apikey:xx-xxxx}")
|
||||
private String apiKey;
|
||||
|
||||
@Value("${llm.chatgpt.apiHost:https://api.openai.com/}")
|
||||
private String apiHost;
|
||||
|
||||
@Value("${llm.chatgpt.proxyIp:default}")
|
||||
private String proxyIp;
|
||||
|
||||
@Value("${llm.chatgpt.proxyPort:8080}")
|
||||
private Integer proxyPort;
|
||||
|
||||
|
||||
public ChatGPT getChatGPT() {
|
||||
Proxy proxy = null;
|
||||
if (!"default".equals(proxyIp)) {
|
||||
proxy = Proxys.http(proxyIp, proxyPort);
|
||||
}
|
||||
return ChatGPT.builder()
|
||||
.apiKey(apiKey)
|
||||
.proxy(proxy)
|
||||
.timeout(900)
|
||||
.apiHost(apiHost) //反向代理地址
|
||||
.build()
|
||||
.init();
|
||||
}
|
||||
|
||||
public String inferredTime(String queryText) {
|
||||
long nowTime = System.currentTimeMillis();
|
||||
Date date = new Date(nowTime);
|
||||
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
|
||||
String formattedDate = sdf.format(date);
|
||||
Message system = Message.ofSystem("现在时间 " + formattedDate + ",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。"
|
||||
+ "你需要遵守以下规则:\n"
|
||||
+ "1.返回规范的数据格式,json,如: 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}"
|
||||
+ "2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n"
|
||||
+ "3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n"
|
||||
+ "4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n"
|
||||
+ "5.时间给出要是绝对正确,不能瞎编\n"
|
||||
);
|
||||
Message message = Message.of("输入:" + queryText + ",输出:");
|
||||
ChatCompletion chatCompletion = ChatCompletion.builder()
|
||||
.model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName())
|
||||
.messages(Arrays.asList(system, message))
|
||||
.maxTokens(10000)
|
||||
.temperature(0.9)
|
||||
.build();
|
||||
ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion);
|
||||
Message res = response.getChoices().get(0).getMessage();
|
||||
return res.getContent();
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.utils;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.parser.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
@@ -15,10 +18,11 @@ public class ComponentFactory {
|
||||
|
||||
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
||||
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
||||
|
||||
private static List<DSLOptimizer> dslCorrections = new ArrayList<>();
|
||||
private static SemanticLayer semanticLayer;
|
||||
private static QuerySelector querySelector;
|
||||
private static ModelResolver modelResolver;
|
||||
|
||||
public static List<SchemaMapper> getSchemaMappers() {
|
||||
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
|
||||
}
|
||||
@@ -27,6 +31,11 @@ public class ComponentFactory {
|
||||
return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers;
|
||||
}
|
||||
|
||||
public static List<DSLOptimizer> getSqlCorrections() {
|
||||
return CollectionUtils.isEmpty(dslCorrections) ? init(DSLOptimizer.class, dslCorrections) : dslCorrections;
|
||||
}
|
||||
|
||||
|
||||
public static SemanticLayer getSemanticLayer() {
|
||||
if (Objects.isNull(semanticLayer)) {
|
||||
semanticLayer = init(SemanticLayer.class);
|
||||
|
||||
@@ -74,7 +74,7 @@ public class NatureHelper {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static boolean isDimensionValueClassId(String nature) {
|
||||
public static boolean isDimensionValueModelId(String nature) {
|
||||
if (StringUtils.isEmpty(nature)) {
|
||||
return false;
|
||||
}
|
||||
@@ -104,7 +104,7 @@ public class NatureHelper {
|
||||
}
|
||||
|
||||
private static long getDimensionValueCount(List<Term> terms) {
|
||||
return terms.stream().filter(term -> isDimensionValueClassId(term.nature.toString())).count();
|
||||
return terms.stream().filter(term -> isDimensionValueModelId(term.nature.toString())).count();
|
||||
}
|
||||
|
||||
private static long getDimensionCount(List<Term> terms) {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
# python path
|
||||
export python_path="/usr/local/bin/python3.9"
|
||||
# pip path
|
||||
|
||||
303
chat/core/src/main/resources/mapper/AgentDOMapper.xml
Normal file
303
chat/core/src/main/resources/mapper/AgentDOMapper.xml
Normal file
@@ -0,0 +1,303 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper">
|
||||
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
<id column="id" jdbcType="INTEGER" property="id" />
|
||||
<result column="name" jdbcType="VARCHAR" property="name" />
|
||||
<result column="description" jdbcType="VARCHAR" property="description" />
|
||||
<result column="status" jdbcType="INTEGER" property="status" />
|
||||
<result column="examples" jdbcType="VARCHAR" property="examples" />
|
||||
<result column="config" jdbcType="VARCHAR" property="config" />
|
||||
<result column="created_by" jdbcType="VARCHAR" property="createdBy" />
|
||||
<result column="created_at" jdbcType="TIMESTAMP" property="createdAt" />
|
||||
<result column="updated_by" jdbcType="VARCHAR" property="updatedBy" />
|
||||
<result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt" />
|
||||
<result column="enable_search" jdbcType="INTEGER" property="enableSearch" />
|
||||
</resultMap>
|
||||
<sql id="Example_Where_Clause">
|
||||
<where>
|
||||
<foreach collection="oredCriteria" item="criteria" separator="or">
|
||||
<if test="criteria.valid">
|
||||
<trim prefix="(" prefixOverrides="and" suffix=")">
|
||||
<foreach collection="criteria.criteria" item="criterion">
|
||||
<choose>
|
||||
<when test="criterion.noValue">
|
||||
and ${criterion.condition}
|
||||
</when>
|
||||
<when test="criterion.singleValue">
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
</choose>
|
||||
</foreach>
|
||||
</trim>
|
||||
</if>
|
||||
</foreach>
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Update_By_Example_Where_Clause">
|
||||
<where>
|
||||
<foreach collection="example.oredCriteria" item="criteria" separator="or">
|
||||
<if test="criteria.valid">
|
||||
<trim prefix="(" prefixOverrides="and" suffix=")">
|
||||
<foreach collection="criteria.criteria" item="criterion">
|
||||
<choose>
|
||||
<when test="criterion.noValue">
|
||||
and ${criterion.condition}
|
||||
</when>
|
||||
<when test="criterion.singleValue">
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
</choose>
|
||||
</foreach>
|
||||
</trim>
|
||||
</if>
|
||||
</foreach>
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Base_Column_List">
|
||||
id, name, description, status, examples, config, created_by, created_at, updated_by,
|
||||
updated_at, enable_search
|
||||
</sql>
|
||||
<select id="selectByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample" resultMap="BaseResultMap">
|
||||
select
|
||||
<if test="distinct">
|
||||
distinct
|
||||
</if>
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_agent
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
<if test="orderByClause != null">
|
||||
order by ${orderByClause}
|
||||
</if>
|
||||
<if test="limitStart != null and limitStart>=0">
|
||||
limit #{limitStart} , #{limitEnd}
|
||||
</if>
|
||||
</select>
|
||||
<select id="selectByPrimaryKey" parameterType="java.lang.Integer" resultMap="BaseResultMap">
|
||||
select
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_agent
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</select>
|
||||
<delete id="deleteByPrimaryKey" parameterType="java.lang.Integer">
|
||||
delete from s2_agent
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</delete>
|
||||
<insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
insert into s2_agent (id, name, description,
|
||||
status, examples, config,
|
||||
created_by, created_at, updated_by,
|
||||
updated_at, enable_search)
|
||||
values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR},
|
||||
#{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR},
|
||||
#{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR},
|
||||
#{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER})
|
||||
</insert>
|
||||
<insert id="insertSelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
insert into s2_agent
|
||||
<trim prefix="(" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
id,
|
||||
</if>
|
||||
<if test="name != null">
|
||||
name,
|
||||
</if>
|
||||
<if test="description != null">
|
||||
description,
|
||||
</if>
|
||||
<if test="status != null">
|
||||
status,
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
examples,
|
||||
</if>
|
||||
<if test="config != null">
|
||||
config,
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
created_by,
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
created_at,
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
updated_by,
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
updated_at,
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
enable_search,
|
||||
</if>
|
||||
</trim>
|
||||
<trim prefix="values (" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
#{id,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="name != null">
|
||||
#{name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="description != null">
|
||||
#{description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="status != null">
|
||||
#{status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
#{examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="config != null">
|
||||
#{config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
#{createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
#{createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
#{updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
#{updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
#{enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</trim>
|
||||
</insert>
|
||||
<select id="countByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample" resultType="java.lang.Long">
|
||||
select count(*) from s2_agent
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
</select>
|
||||
<update id="updateByExampleSelective" parameterType="map">
|
||||
update s2_agent
|
||||
<set>
|
||||
<if test="record.id != null">
|
||||
id = #{record.id,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="record.name != null">
|
||||
name = #{record.name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.description != null">
|
||||
description = #{record.description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.status != null">
|
||||
status = #{record.status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="record.examples != null">
|
||||
examples = #{record.examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.config != null">
|
||||
config = #{record.config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.createdBy != null">
|
||||
created_by = #{record.createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.createdAt != null">
|
||||
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="record.updatedBy != null">
|
||||
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.updatedAt != null">
|
||||
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="record.enableSearch != null">
|
||||
enable_search = #{record.enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</set>
|
||||
<if test="_parameter != null">
|
||||
<include refid="Update_By_Example_Where_Clause" />
|
||||
</if>
|
||||
</update>
|
||||
<update id="updateByExample" parameterType="map">
|
||||
update s2_agent
|
||||
set id = #{record.id,jdbcType=INTEGER},
|
||||
name = #{record.name,jdbcType=VARCHAR},
|
||||
description = #{record.description,jdbcType=VARCHAR},
|
||||
status = #{record.status,jdbcType=INTEGER},
|
||||
examples = #{record.examples,jdbcType=VARCHAR},
|
||||
config = #{record.config,jdbcType=VARCHAR},
|
||||
created_by = #{record.createdBy,jdbcType=VARCHAR},
|
||||
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
|
||||
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
|
||||
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
|
||||
enable_search = #{record.enableSearch,jdbcType=INTEGER}
|
||||
<if test="_parameter != null">
|
||||
<include refid="Update_By_Example_Where_Clause" />
|
||||
</if>
|
||||
</update>
|
||||
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
update s2_agent
|
||||
<set>
|
||||
<if test="name != null">
|
||||
name = #{name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="description != null">
|
||||
description = #{description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="status != null">
|
||||
status = #{status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
examples = #{examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="config != null">
|
||||
config = #{config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
created_by = #{createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
created_at = #{createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
updated_by = #{updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
enable_search = #{enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</set>
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</update>
|
||||
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
update s2_agent
|
||||
set name = #{name,jdbcType=VARCHAR},
|
||||
description = #{description,jdbcType=VARCHAR},
|
||||
status = #{status,jdbcType=INTEGER},
|
||||
examples = #{examples,jdbcType=VARCHAR},
|
||||
config = #{config,jdbcType=VARCHAR},
|
||||
created_by = #{createdBy,jdbcType=VARCHAR},
|
||||
created_at = #{createdAt,jdbcType=TIMESTAMP},
|
||||
updated_by = #{updatedBy,jdbcType=VARCHAR},
|
||||
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
|
||||
enable_search = #{enableSearch,jdbcType=INTEGER}
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</update>
|
||||
</mapper>
|
||||
@@ -60,3 +60,17 @@ CREATE TABLE `chat_query`
|
||||
KEY `common1` (`user_name`),
|
||||
KEY `common2` (`chat_id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
|
||||
|
||||
|
||||
CREATE TABLE `chat`
|
||||
(
|
||||
`chat_id` bigint(8) NOT NULL AUTO_INCREMENT,
|
||||
`chat_name` varchar(100) DEFAULT NULL,
|
||||
`create_time` datetime DEFAULT NULL,
|
||||
`last_time` datetime DEFAULT NULL,
|
||||
`creator` varchar(30) DEFAULT NULL,
|
||||
`last_question` varchar(200) DEFAULT NULL,
|
||||
`is_delete` int(2) DEFAULT '0' COMMENT 'is deleted',
|
||||
`is_top` int(2) DEFAULT '0' COMMENT 'is top',
|
||||
PRIMARY KEY (`chat_id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class DateFieldCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
CorrectionInfo rewriter = dateFieldCorrector.rewriter(correctionInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", rewriter.getSql());
|
||||
|
||||
correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
rewriter = dateFieldCorrector.rewriter(correctionInfo);
|
||||
|
||||
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", rewriter.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class SelectFieldAppendCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
|
||||
.build();
|
||||
|
||||
CorrectionInfo rewriter = corrector.rewriter(correctionInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名, 歌手名, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
|
||||
rewriter.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,17 @@
|
||||
package com.tencent.supersonic.knowledge.dictionary.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* dimension word nature
|
||||
@@ -23,6 +27,7 @@ public class DimensionWordBuilder extends BaseWordBuilder {
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOnwWordNature(word, schemaElement, false));
|
||||
result.addAll(getOnwWordNatureAlias(schemaElement, false));
|
||||
if (nlpDimensionUseSuffix) {
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) {
|
||||
@@ -46,4 +51,16 @@ public class DimensionWordBuilder extends BaseWordBuilder {
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOnwWordNature(alias, schemaElement, false));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package com.tencent.supersonic.knowledge.dictionary.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
@@ -22,6 +26,7 @@ public class MetricWordBuilder extends BaseWordBuilder {
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOnwWordNature(word, schemaElement, false));
|
||||
result.addAll(getOnwWordNatureAlias(schemaElement, false));
|
||||
if (nlpMetricUseSuffix) {
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (!word.equalsIgnoreCase(reverseWord)) {
|
||||
@@ -45,4 +50,16 @@ public class MetricWordBuilder extends BaseWordBuilder {
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOnwWordNature(alias, schemaElement, false));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,51 +2,38 @@ package com.tencent.supersonic.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.*;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.DomainService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.query.service.QueryService;
|
||||
import com.tencent.supersonic.semantic.query.service.SchemaService;
|
||||
import java.util.List;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class LocalSemanticLayer extends BaseSemanticLayer {
|
||||
|
||||
private SchemaService schemaService;
|
||||
private S2ThreadContext s2ThreadContext;
|
||||
private DomainService domainService;
|
||||
private ModelService modelService;
|
||||
private DimensionService dimensionService;
|
||||
private MetricService metricService;
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
|
||||
try {
|
||||
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user){
|
||||
QueryService queryService = ContextUtils.getBean(QueryService.class);
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = queryService.queryByStruct(queryStructReq, user);
|
||||
return queryResultWithSchemaResp;
|
||||
} catch (Exception e) {
|
||||
log.info("queryByStruct has an exception:{}", e.toString());
|
||||
}
|
||||
return null;
|
||||
return queryService.queryByStructWithAuth(queryStructReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -8,21 +8,17 @@ import com.tencent.supersonic.semantic.api.model.pojo.Entity;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ModelSchemaBuilder {
|
||||
|
||||
private static String aliasSplit = ",";
|
||||
|
||||
public static ModelSchema build(ModelSchemaResp resp) {
|
||||
ModelSchema domainSchema = new ModelSchema();
|
||||
|
||||
@@ -37,6 +33,13 @@ public class ModelSchemaBuilder {
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
for (MetricSchemaResp metric : resp.getMetrics()) {
|
||||
|
||||
List<String> alias = new ArrayList<>();
|
||||
String aliasStr = metric.getAlias();
|
||||
if (Strings.isNotEmpty(aliasStr)) {
|
||||
alias = Arrays.asList(aliasStr.split(aliasSplit));
|
||||
}
|
||||
|
||||
SchemaElement metricToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(metric.getId())
|
||||
@@ -44,16 +47,10 @@ public class ModelSchemaBuilder {
|
||||
.bizName(metric.getBizName())
|
||||
.type(SchemaElementType.METRIC)
|
||||
.useCnt(metric.getUseCnt())
|
||||
.alias(alias)
|
||||
.build();
|
||||
metrics.add(metricToAdd);
|
||||
|
||||
String alias = metric.getAlias();
|
||||
if (StringUtils.isNotEmpty(alias)) {
|
||||
SchemaElement alisMetricToAdd = new SchemaElement();
|
||||
BeanUtils.copyProperties(metricToAdd, alisMetricToAdd);
|
||||
alisMetricToAdd.setName(alias);
|
||||
metrics.add(alisMetricToAdd);
|
||||
}
|
||||
}
|
||||
domainSchema.getMetrics().addAll(metrics);
|
||||
|
||||
@@ -74,6 +71,11 @@ public class ModelSchemaBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
List<String> alias = new ArrayList<>();
|
||||
String aliasStr = dim.getAlias();
|
||||
if (Strings.isNotEmpty(aliasStr)) {
|
||||
alias = Arrays.asList(aliasStr.split(aliasSplit));
|
||||
}
|
||||
SchemaElement dimToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
@@ -81,17 +83,10 @@ public class ModelSchemaBuilder {
|
||||
.bizName(dim.getBizName())
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(alias)
|
||||
.build();
|
||||
dimensions.add(dimToAdd);
|
||||
|
||||
String alias = dim.getAlias();
|
||||
if (StringUtils.isNotEmpty(alias)) {
|
||||
SchemaElement alisDimToAdd = new SchemaElement();
|
||||
BeanUtils.copyProperties(dimToAdd, alisDimToAdd);
|
||||
alisDimToAdd.setName(alias);
|
||||
dimensions.add(alisDimToAdd);
|
||||
}
|
||||
|
||||
SchemaElement dimValueToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
@@ -115,7 +110,7 @@ public class ModelSchemaBuilder {
|
||||
.collect(
|
||||
Collectors.toMap(SchemaElement::getId, schemaElement -> schemaElement, (k1, k2) -> k2));
|
||||
if (idAndDimPair.containsKey(entity.getEntityId())) {
|
||||
entityElement = idAndDimPair.get(entity.getEntityId());
|
||||
BeanUtils.copyProperties(idAndDimPair.get(entity.getEntityId()), entityElement);
|
||||
entityElement.setType(SchemaElementType.ENTITY);
|
||||
}
|
||||
entityElement.setAlias(entity.getNames());
|
||||
|
||||
@@ -118,6 +118,12 @@
|
||||
</dependency>
|
||||
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.plexpt</groupId>
|
||||
<artifactId>chatgpt</artifactId>
|
||||
<version>4.1.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.pagehelper</groupId>
|
||||
<artifactId>pagehelper</artifactId>
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
|
||||
import com.plexpt.chatgpt.ChatGPT;
|
||||
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
|
||||
import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse;
|
||||
import com.plexpt.chatgpt.entity.chat.Message;
|
||||
import com.plexpt.chatgpt.util.Proxys;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.net.Proxy;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class ChatGptHelper {
|
||||
|
||||
@Value("${llm.chatgpt.apikey:}")
|
||||
private String apiKey;
|
||||
|
||||
@Value("${llm.chatgpt.apiHost:}")
|
||||
private String apiHost;
|
||||
|
||||
@Value("${llm.chatgpt.proxyIp:}")
|
||||
private String proxyIp;
|
||||
|
||||
@Value("${llm.chatgpt.proxyPort:}")
|
||||
private Integer proxyPort;
|
||||
|
||||
|
||||
public ChatGPT getChatGPT(){
|
||||
Proxy proxy = null;
|
||||
if (!"default".equals(proxyIp)){
|
||||
proxy = Proxys.http(proxyIp, proxyPort);
|
||||
}
|
||||
return ChatGPT.builder()
|
||||
.apiKey(apiKey)
|
||||
.proxy(proxy)
|
||||
.timeout(900)
|
||||
.apiHost(apiHost) //反向代理地址
|
||||
.build()
|
||||
.init();
|
||||
}
|
||||
|
||||
public Message getChatCompletion(Message system,Message message){
|
||||
ChatCompletion chatCompletion = ChatCompletion.builder()
|
||||
.model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName())
|
||||
.messages(Arrays.asList(system, message))
|
||||
.maxTokens(10000)
|
||||
.temperature(0.9)
|
||||
.build();
|
||||
ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion);
|
||||
return response.getChoices().get(0).getMessage();
|
||||
}
|
||||
|
||||
public String inferredTime(String queryText){
|
||||
long nowTime = System.currentTimeMillis();
|
||||
Date date = new Date(nowTime);
|
||||
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
|
||||
String formattedDate = sdf.format(date);
|
||||
Message system = Message.ofSystem("现在时间 "+formattedDate+",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。" +
|
||||
"你需要遵守以下规则:\n" +
|
||||
"1.返回规范的数据格式,json,如: 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}" +
|
||||
"2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n"+
|
||||
"3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n" +
|
||||
"4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n"+
|
||||
"5.时间给出要是绝对正确,不能瞎编\n"
|
||||
);
|
||||
Message message = Message.of("输入:"+queryText+",输出:");
|
||||
Message res = getChatCompletion(system, message);
|
||||
return res.getContent();
|
||||
}
|
||||
|
||||
|
||||
public String mockAlias(String mockType,String name,String bizName,String table,String desc,Boolean isPercentage){
|
||||
String msg = "Assuming you are a professional data analyst specializing in indicators, you have a vast amount of data analysis indicator content. You are familiar with the basic format of the content,Now, Construct your answer Based on the following json-schema.\n" +
|
||||
"{\n" +
|
||||
"\"$schema\": \"http://json-schema.org/draft-07/schema#\",\n" +
|
||||
"\"type\": \"array\",\n" +
|
||||
"\"minItems\": 2,\n" +
|
||||
"\"maxItems\": 4,\n" +
|
||||
"\"items\": {\n" +
|
||||
"\"type\": \"string\",\n" +
|
||||
"\"description\": \"Assuming you are a data analyst and give a defined "+mockType+" name: " +name+","+
|
||||
"this "+mockType+" is from database and table: "+table+ ",This "+mockType+" calculates the field source: "+bizName+", The description of this indicator is: "+desc+", provide some aliases for this,please take chinese or english,but more chinese and Not repeating,\"\n" +
|
||||
"},\n" +
|
||||
"\"additionalProperties\":false}\n" +
|
||||
"Please double-check whether the answer conforms to the format described in the JSON-schema.\n" +
|
||||
"ANSWER JSON:";
|
||||
log.info("msg:{}",msg);
|
||||
Message system = Message.ofSystem("");
|
||||
Message message = Message.of(msg);
|
||||
Message res = getChatCompletion(system, message);
|
||||
return res.getContent();
|
||||
}
|
||||
|
||||
|
||||
|
||||
public String mockDimensionValueAlias(String json){
|
||||
String msg = "Assuming you are a professional data analyst specializing in indicators,for you a json list," +
|
||||
"the required content to follow is as follows: " +
|
||||
"1. The format of JSON," +
|
||||
"2. Only return in JSON format," +
|
||||
"3. the array item > 1 and < 5,more alias," +
|
||||
"for example:input:[\"qq_music\",\"kugou_music\"],out:{\"tran\":[\"qq音乐\",\"酷狗音乐\"],\"alias\":{\"qq_music\":[\"q音\",\"qq音乐\"],\"kugou_music\":[\"kugou\",\"酷狗\"]}}," +
|
||||
"now input: " + json + ","+
|
||||
"answer json:";
|
||||
log.info("msg:{}",msg);
|
||||
Message system = Message.ofSystem("");
|
||||
Message message = Message.of(msg);
|
||||
Message res = getChatCompletion(system, message);
|
||||
return res.getContent();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
ChatGptHelper chatGptHelper = new ChatGptHelper();
|
||||
System.out.println(chatGptHelper.mockAlias("","","","","",false));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,11 +1,13 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import java.text.ParseException;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.time.temporal.TemporalAdjusters;
|
||||
import java.util.Calendar;
|
||||
import java.util.Date;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
@@ -56,30 +58,37 @@ public class DateUtils {
|
||||
return dateFormat.format(calendar.getTime());
|
||||
}
|
||||
|
||||
|
||||
public static String getBeforeDate(String date, int intervalDay, DatePeriodEnum datePeriodEnum) {
|
||||
Calendar calendar = Calendar.getInstance();
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT_DOT);
|
||||
try {
|
||||
calendar.setTime(dateFormat.parse(date));
|
||||
} catch (ParseException e) {
|
||||
log.error("parse error");
|
||||
}
|
||||
DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(DATE_FORMAT_DOT);
|
||||
LocalDate currentDate = LocalDate.parse(date, dateTimeFormatter);
|
||||
LocalDate result = null;
|
||||
switch (datePeriodEnum) {
|
||||
case DAY:
|
||||
calendar.set(Calendar.DATE, calendar.get(Calendar.DATE) - intervalDay);
|
||||
result = currentDate.minusDays(intervalDay);
|
||||
break;
|
||||
case WEEK:
|
||||
calendar.set(Calendar.DATE, calendar.get(Calendar.DATE) - intervalDay * 7);
|
||||
result = currentDate.minusWeeks(intervalDay);
|
||||
if (intervalDay == 0) {
|
||||
result = result.with(TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY));
|
||||
}
|
||||
break;
|
||||
case MONTH:
|
||||
calendar.set(Calendar.MONTH, calendar.get(Calendar.MONTH) - intervalDay);
|
||||
result = currentDate.minusMonths(intervalDay);
|
||||
if (intervalDay == 0) {
|
||||
result = result.with(TemporalAdjusters.firstDayOfMonth());
|
||||
}
|
||||
break;
|
||||
case YEAR:
|
||||
calendar.set(Calendar.YEAR, calendar.get(Calendar.YEAR) - intervalDay);
|
||||
result = currentDate.minusYears(intervalDay);
|
||||
if (intervalDay == 0) {
|
||||
result = result.with(TemporalAdjusters.firstDayOfYear());
|
||||
}
|
||||
break;
|
||||
default:
|
||||
}
|
||||
return dateFormat.format(calendar.getTime());
|
||||
if (Objects.nonNull(result)) {
|
||||
return result.format(DateTimeFormatter.ofPattern(DATE_FORMAT_DOT));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ public class StringUtil {
|
||||
public static final String COMMA_WRAPPER = "'%s'";
|
||||
public static final String SPACE_WRAPPER = " %s ";
|
||||
|
||||
|
||||
public static String getCommaWrap(String value) {
|
||||
return String.format(COMMA_WRAPPER, value);
|
||||
}
|
||||
|
||||
@@ -41,7 +41,9 @@ public class CCJSqlParserUtils {
|
||||
}
|
||||
Set<String> result = new HashSet<>();
|
||||
Expression where = plainSelect.getWhere();
|
||||
if (Objects.nonNull(where)) {
|
||||
where.accept(new FieldAcquireVisitor(result));
|
||||
}
|
||||
return new ArrayList<>(result);
|
||||
}
|
||||
|
||||
@@ -166,7 +168,24 @@ public class CCJSqlParserUtils {
|
||||
if (Objects.nonNull(groupByElement)) {
|
||||
groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName));
|
||||
}
|
||||
//5. add Waiting Expression
|
||||
return selectStatement.toString();
|
||||
}
|
||||
|
||||
|
||||
public static String replaceFunction(String sql) {
|
||||
Select selectStatement = getSelect(sql);
|
||||
SelectBody selectBody = selectStatement.getSelectBody();
|
||||
if (!(selectBody instanceof PlainSelect)) {
|
||||
return sql;
|
||||
}
|
||||
PlainSelect plainSelect = (PlainSelect) selectBody;
|
||||
//1. replace where dataDiff function
|
||||
Expression where = plainSelect.getWhere();
|
||||
FunctionReplaceVisitor visitor = new FunctionReplaceVisitor();
|
||||
if (Objects.nonNull(where)) {
|
||||
where.accept(visitor);
|
||||
}
|
||||
//2. add Waiting Expression
|
||||
List<Expression> waitingForAdds = visitor.getWaitingForAdds();
|
||||
addWaitingExpression(plainSelect, where, waitingForAdds);
|
||||
return selectStatement.toString();
|
||||
@@ -181,9 +200,10 @@ public class CCJSqlParserUtils {
|
||||
if (where == null) {
|
||||
plainSelect.setWhere(expression);
|
||||
} else {
|
||||
plainSelect.setWhere(new AndExpression(where, expression));
|
||||
where = new AndExpression(where, expression);
|
||||
}
|
||||
}
|
||||
plainSelect.setWhere(where);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
|
||||
@Slf4j
|
||||
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||
|
||||
private Map<String, String> fieldToBizName;
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
|
||||
public FieldReplaceVisitor(Map<String, String> fieldToBizName) {
|
||||
this.fieldToBizName = fieldToBizName;
|
||||
@@ -29,42 +19,4 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
public void visit(Column column) {
|
||||
parseVisitorHelper.replaceColumn(column, fieldToBizName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThan expr) {
|
||||
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, ">");
|
||||
if (Objects.nonNull(expression)) {
|
||||
waitingForAdds.add(expression);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThanEquals expr) {
|
||||
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, ">=");
|
||||
if (Objects.nonNull(expression)) {
|
||||
waitingForAdds.add(expression);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThan expr) {
|
||||
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, "<");
|
||||
if (Objects.nonNull(expression)) {
|
||||
waitingForAdds.add(expression);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThanEquals expr) {
|
||||
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, "<=");
|
||||
if (Objects.nonNull(expression)) {
|
||||
waitingForAdds.add(expression);
|
||||
}
|
||||
}
|
||||
|
||||
public List<Expression> getWaitingForAdds() {
|
||||
return waitingForAdds;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import com.tencent.supersonic.common.util.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.DoubleValue;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionReplaceVisitor extends ExpressionVisitorAdapter {
|
||||
|
||||
public static final String DATE_FUNCTION = "datediff";
|
||||
public static final double HALF_YEAR = 0.5d;
|
||||
public static final int SIX_MONTH = 6;
|
||||
public static final String EQUAL = "=";
|
||||
private List<Expression> waitingForAdds = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void visit(MinorThan expr) {
|
||||
List<Expression> expressions = reparseDate(expr, ">");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(EqualsTo expr) {
|
||||
List<Expression> expressions = reparseDate(expr, ">=");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(MinorThanEquals expr) {
|
||||
List<Expression> expressions = reparseDate(expr, ">=");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThan expr) {
|
||||
List<Expression> expressions = reparseDate(expr, "<");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(GreaterThanEquals expr) {
|
||||
List<Expression> expressions = reparseDate(expr, "<=");
|
||||
if (Objects.nonNull(expressions)) {
|
||||
waitingForAdds.addAll(expressions);
|
||||
}
|
||||
}
|
||||
|
||||
public List<Expression> getWaitingForAdds() {
|
||||
return waitingForAdds;
|
||||
}
|
||||
|
||||
|
||||
public List<Expression> reparseDate(ComparisonOperator comparisonOperator, String startDateOperator) {
|
||||
List<Expression> result = new ArrayList<>();
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return result;
|
||||
}
|
||||
Function leftExpressionFunction = (Function) leftExpression;
|
||||
if (!leftExpressionFunction.toString().contains(DATE_FUNCTION)) {
|
||||
return result;
|
||||
}
|
||||
List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
||||
if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) {
|
||||
return result;
|
||||
}
|
||||
Column field = (Column) leftExpressions.get(1);
|
||||
String columnName = field.getColumnName();
|
||||
try {
|
||||
String startDateValue = getStartDateStr(comparisonOperator, leftExpressions);
|
||||
String endDateValue = getEndDateValue(leftExpressions);
|
||||
String endDateOperator = comparisonOperator.getStringExpression();
|
||||
String condExpr =
|
||||
columnName + StringUtil.getSpaceWrap(getEndDateOperator(comparisonOperator))
|
||||
+ StringUtil.getCommaWrap(endDateValue);
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
|
||||
String startDataCondExpr =
|
||||
columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
|
||||
|
||||
if (EQUAL.equalsIgnoreCase(endDateOperator)) {
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(condExpr));
|
||||
expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
|
||||
}
|
||||
comparisonOperator.setLeftExpression(null);
|
||||
comparisonOperator.setRightExpression(null);
|
||||
comparisonOperator.setASTNode(null);
|
||||
|
||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
||||
comparisonOperator.setASTNode(expression.getASTNode());
|
||||
|
||||
result.add(CCJSqlParserUtil.parseCondExpression(startDataCondExpr));
|
||||
return result;
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private String getStartDateStr(ComparisonOperator minorThanEquals, List<Expression> expressions) {
|
||||
String unitValue = getUnit(expressions);
|
||||
String dateValue = getEndDateValue(expressions);
|
||||
String dateStr = "";
|
||||
Expression rightExpression = minorThanEquals.getRightExpression();
|
||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(unitValue);
|
||||
if (rightExpression instanceof DoubleValue) {
|
||||
DoubleValue value = (DoubleValue) rightExpression;
|
||||
double doubleValue = value.getValue();
|
||||
if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) {
|
||||
datePeriodEnum = DatePeriodEnum.MONTH;
|
||||
dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum);
|
||||
}
|
||||
} else if (rightExpression instanceof LongValue) {
|
||||
LongValue value = (LongValue) rightExpression;
|
||||
long doubleValue = value.getValue();
|
||||
dateStr = DateUtils.getBeforeDate(dateValue, (int) doubleValue, datePeriodEnum);
|
||||
}
|
||||
return dateStr;
|
||||
}
|
||||
|
||||
private String getEndDateOperator(ComparisonOperator comparisonOperator) {
|
||||
String operator = comparisonOperator.getStringExpression();
|
||||
if (EQUAL.equalsIgnoreCase(operator)) {
|
||||
operator = "<=";
|
||||
}
|
||||
return operator;
|
||||
}
|
||||
|
||||
private String getEndDateValue(List<Expression> leftExpressions) {
|
||||
StringValue date = (StringValue) leftExpressions.get(2);
|
||||
return date.getValue();
|
||||
}
|
||||
|
||||
private String getUnit(List<Expression> expressions) {
|
||||
StringValue unit = (StringValue) expressions.get(0);
|
||||
return unit.getValue();
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -2,13 +2,18 @@ package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import net.sf.jsqlparser.statement.select.GroupByElement;
|
||||
import net.sf.jsqlparser.statement.select.GroupByVisitor;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GroupByReplaceVisitor implements GroupByVisitor {
|
||||
|
||||
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
|
||||
@@ -29,8 +34,18 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
|
||||
|
||||
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName);
|
||||
if (StringUtils.isNotEmpty(replaceColumn)) {
|
||||
if (expression instanceof Column) {
|
||||
groupByExpressions.set(i, new Column(replaceColumn));
|
||||
}
|
||||
if (expression instanceof Function) {
|
||||
try {
|
||||
Expression element = CCJSqlParserUtil.parseExpression(replaceColumn);
|
||||
((Function) expression).getParameters().getExpressions().set(0, element);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("e", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,33 +1,16 @@
|
||||
package com.tencent.supersonic.common.util.jsqlparser;
|
||||
|
||||
import com.tencent.supersonic.common.util.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.DoubleValue;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.expression.Function;
|
||||
import net.sf.jsqlparser.expression.LongValue;
|
||||
import net.sf.jsqlparser.expression.StringValue;
|
||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import net.sf.jsqlparser.schema.Column;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class ParseVisitorHelper {
|
||||
|
||||
public static final double HALF_YEAR = 0.5d;
|
||||
public static final int SIX_MONTH = 6;
|
||||
public static final String DATE_FUNCTION = "datediff";
|
||||
|
||||
public void replaceColumn(Column column, Map<String, String> fieldToBizName) {
|
||||
String columnName = column.getColumnName();
|
||||
column.setColumnName(getReplaceColumn(columnName, fieldToBizName));
|
||||
@@ -53,88 +36,6 @@ public class ParseVisitorHelper {
|
||||
return columnName;
|
||||
}
|
||||
|
||||
public Expression reparseDate(ComparisonOperator comparisonOperator, Map<String, String> fieldToBizName,
|
||||
String startDateOperator) {
|
||||
Expression leftExpression = comparisonOperator.getLeftExpression();
|
||||
if (leftExpression instanceof Column) {
|
||||
Column leftExpressionColumn = (Column) leftExpression;
|
||||
replaceColumn(leftExpressionColumn, fieldToBizName);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!(leftExpression instanceof Function)) {
|
||||
return null;
|
||||
}
|
||||
Function leftExpressionFunction = (Function) leftExpression;
|
||||
if (!leftExpressionFunction.toString().contains(DATE_FUNCTION)) {
|
||||
return null;
|
||||
}
|
||||
List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
|
||||
if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) {
|
||||
return null;
|
||||
}
|
||||
Column field = (Column) leftExpressions.get(1);
|
||||
String columnName = field.getColumnName();
|
||||
String startDateValue = getStartDateStr(comparisonOperator, leftExpressions);
|
||||
String fieldBizName = fieldToBizName.get(columnName);
|
||||
try {
|
||||
String endDateValue = getEndDateValue(leftExpressions);
|
||||
String stringExpression = comparisonOperator.getStringExpression();
|
||||
|
||||
String condExpr =
|
||||
fieldBizName + StringUtil.getSpaceWrap(stringExpression) + StringUtil.getCommaWrap(endDateValue);
|
||||
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
|
||||
|
||||
comparisonOperator.setLeftExpression(null);
|
||||
comparisonOperator.setRightExpression(null);
|
||||
comparisonOperator.setASTNode(null);
|
||||
|
||||
comparisonOperator.setLeftExpression(expression.getLeftExpression());
|
||||
comparisonOperator.setRightExpression(expression.getRightExpression());
|
||||
comparisonOperator.setASTNode(expression.getASTNode());
|
||||
|
||||
String startDataCondExpr =
|
||||
fieldBizName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
|
||||
return CCJSqlParserUtil.parseCondExpression(startDataCondExpr);
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("JSQLParserException", e);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private String getEndDateValue(List<Expression> leftExpressions) {
|
||||
StringValue date = (StringValue) leftExpressions.get(2);
|
||||
return date.getValue();
|
||||
}
|
||||
|
||||
private String getStartDateStr(ComparisonOperator minorThanEquals, List<Expression> expressions) {
|
||||
String unitValue = getUnit(expressions);
|
||||
String dateValue = getEndDateValue(expressions);
|
||||
String dateStr = "";
|
||||
Expression rightExpression = minorThanEquals.getRightExpression();
|
||||
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(unitValue);
|
||||
if (rightExpression instanceof DoubleValue) {
|
||||
DoubleValue value = (DoubleValue) rightExpression;
|
||||
double doubleValue = value.getValue();
|
||||
if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) {
|
||||
datePeriodEnum = DatePeriodEnum.MONTH;
|
||||
dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum);
|
||||
}
|
||||
} else if (rightExpression instanceof LongValue) {
|
||||
LongValue value = (LongValue) rightExpression;
|
||||
long doubleValue = value.getValue();
|
||||
dateStr = DateUtils.getBeforeDate(dateValue, (int) doubleValue, datePeriodEnum);
|
||||
}
|
||||
return dateStr;
|
||||
}
|
||||
|
||||
private String getUnit(List<Expression> expressions) {
|
||||
StringValue unit = (StringValue) expressions.get(0);
|
||||
return unit.getValue();
|
||||
}
|
||||
|
||||
|
||||
public static int editDistance(String word1, String word2) {
|
||||
final int m = word1.length();
|
||||
final int n = word2.length();
|
||||
|
||||
@@ -14,13 +14,25 @@ class DateUtilsTest {
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-10", 8, DatePeriodEnum.DAY);
|
||||
Assert.assertEquals(dateStr, "2023-08-02");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.DAY);
|
||||
Assert.assertEquals(dateStr, "2023-08-10");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-10", 1, DatePeriodEnum.WEEK);
|
||||
Assert.assertEquals(dateStr, "2023-08-03");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.WEEK);
|
||||
Assert.assertEquals(dateStr, "2023-08-07");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-01", 1, DatePeriodEnum.MONTH);
|
||||
Assert.assertEquals(dateStr, "2023-07-01");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.MONTH);
|
||||
Assert.assertEquals(dateStr, "2023-08-01");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-01", 1, DatePeriodEnum.YEAR);
|
||||
Assert.assertEquals(dateStr, "2022-08-01");
|
||||
|
||||
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.YEAR);
|
||||
Assert.assertEquals(dateStr, "2023-01-01");
|
||||
}
|
||||
}
|
||||
@@ -22,37 +22,104 @@ class CCJSqlParserUtilsTest {
|
||||
String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11";
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND song_publis_date = '2023-08-01' AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)";
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' GROUP BY YEAR(publish_date)",
|
||||
replaceSql);
|
||||
|
||||
replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' group by 发行日期";
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' GROUP BY publish_date",
|
||||
replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30",
|
||||
fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-11' AND play_count > 1000000 AND sys_imp_date <= '2023-08-11' AND publish_date >= '2022-08-11' AND sys_imp_date >= '2023-07-12'"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
|
||||
fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
|
||||
fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name FROM 歌曲库 WHERE 1 = 1 AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-08-09' AND publish_date >= '2023-01-01' ORDER BY play_count DESC LIMIT 11"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
|
||||
fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date >= '2023-02-09' ORDER BY play_count DESC LIMIT 11"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
|
||||
fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT song_name FROM 歌曲库 WHERE publish_date >= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-02-09' ORDER BY play_count DESC LIMIT 11"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 =alice and 发布日期 ='11' order by 访问次数 desc limit 1",
|
||||
"select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice' and 发布日期 ='11' order by 访问次数 desc limit 1",
|
||||
fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, user_id FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"
|
||||
, replaceSql);
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceTable(replaceSql, "s2");
|
||||
|
||||
replaceSql = CCJSqlParserUtils.addFieldsToSelect(replaceSql, Collections.singletonList("field_a"));
|
||||
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(
|
||||
"select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 =alice and 发布日期 ='11' group by 部门 limit 1",
|
||||
"select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice' and 发布日期 ='11' group by 部门 limit 1",
|
||||
fieldToBizName);
|
||||
Assert.assertEquals(replaceSql,
|
||||
"SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = user_id AND publish_date = '11' GROUP BY department LIMIT 1");
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1",
|
||||
replaceSql);
|
||||
|
||||
replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' and 数据日期 <= '2023-08-06' and 部门 = 'hr'";
|
||||
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
|
||||
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
|
||||
|
||||
System.out.println(replaceSql);
|
||||
Assert.assertEquals(
|
||||
"SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' AND sys_imp_date <= '2023-08-06' AND department = 'hr'",
|
||||
replaceSql);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
|
||||
com.tencent.supersonic.chat.api.component.SemanticLayer=\
|
||||
com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer
|
||||
@@ -20,3 +21,12 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
|
||||
|
||||
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector
|
||||
@@ -1,6 +1,12 @@
|
||||
package com.tencent.supersonic;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
@@ -14,16 +20,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.*;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.context.event.ApplicationReadyEvent;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -32,6 +36,7 @@ import org.springframework.stereotype.Component;
|
||||
@Slf4j
|
||||
public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent> {
|
||||
|
||||
@Qualifier("chatQueryService")
|
||||
@Autowired
|
||||
private QueryService queryService;
|
||||
@Autowired
|
||||
@@ -40,6 +45,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
protected ConfigService configService;
|
||||
@Autowired
|
||||
private PluginService pluginService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
private User user = User.getFakeUser();
|
||||
|
||||
@@ -175,43 +182,25 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
pluginService.createPlugin(plugin_1, user);
|
||||
}
|
||||
|
||||
private void addPlugin_2() {
|
||||
Plugin plugin_2 = new Plugin();
|
||||
plugin_2.setType("DSL");
|
||||
plugin_2.setModelList(Arrays.asList(1L, 2L));
|
||||
plugin_2.setPattern("");
|
||||
plugin_2.setParseModeConfig(null);
|
||||
plugin_2.setName("大模型语义解析");
|
||||
List<String> examples = new ArrayList<>();
|
||||
examples.add("超音数访问次数最高的部门是哪个");
|
||||
examples.add("超音数访问人数最高的部门是哪个");
|
||||
|
||||
PluginParseConfig parseConfig = PluginParseConfig.builder()
|
||||
.name("DSL")
|
||||
.description("这个工具能够将用户的自然语言查询转化为SQL语句,从而从数据库中的查询具体的数据。用于处理数据查询的问题,提供基于事实的数据")
|
||||
.examples(examples)
|
||||
.build();
|
||||
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
|
||||
pluginService.createPlugin(plugin_2, user);
|
||||
}
|
||||
|
||||
private void addPlugin_3() {
|
||||
Plugin plugin_2 = new Plugin();
|
||||
plugin_2.setType("CONTENT_INTERPRET");
|
||||
plugin_2.setModelList(Arrays.asList(1L));
|
||||
plugin_2.setPattern("超音数最近访问情况怎么样");
|
||||
plugin_2.setParseModeConfig(null);
|
||||
plugin_2.setName("内容解读");
|
||||
List<String> examples = new ArrayList<>();
|
||||
examples.add("超音数最近访问情况怎么样");
|
||||
examples.add("超音数最近访问情况如何");
|
||||
PluginParseConfig parseConfig = PluginParseConfig.builder()
|
||||
.name("supersonic_content_interpret")
|
||||
.description("这个工具能够先查询到相关的数据并交给大模型进行解读, 最后返回解读结果")
|
||||
.examples(examples)
|
||||
.build();
|
||||
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
|
||||
pluginService.createPlugin(plugin_2, user);
|
||||
private void addAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setId(1);
|
||||
agent.setName("查信息");
|
||||
agent.setDescription("查信息");
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("超音数访问次数", "超音数访问人数", "alice 停留时长"));
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
RuleQueryTool ruleQueryTool = new RuleQueryTool();
|
||||
ruleQueryTool.setType(AgentToolType.RULE);
|
||||
ruleQueryTool.setQueryModes(Lists.newArrayList(
|
||||
"ENTITY_DETAIL", "ENTITY_LIST_FILTER", "ENTITY_ID", "METRIC_ENTITY",
|
||||
"METRIC_FILTER", "METRIC_GROUPBY", "METRIC_MODEL", "METRIC_ORDERBY"
|
||||
));
|
||||
agentConfig.getTools().add(ruleQueryTool);
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -220,8 +209,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
addDemoChatConfig_1();
|
||||
addDemoChatConfig_2();
|
||||
addPlugin_1();
|
||||
addPlugin_2();
|
||||
addPlugin_3();
|
||||
addAgent();
|
||||
addSampleChats();
|
||||
addSampleChats2();
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
|
||||
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
|
||||
com.tencent.supersonic.chat.parser.embedding.EmbeddingBasedParser, \
|
||||
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
|
||||
com.tencent.supersonic.chat.api.component.SemanticLayer=\
|
||||
@@ -21,3 +22,11 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
|
||||
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
|
||||
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector
|
||||
@@ -648,6 +648,22 @@ CREATE TABLE IF NOT EXISTS `s2_plugin`
|
||||
COMMENT
|
||||
ON TABLE s2_plugin IS 'plugin information table';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS s2_agent
|
||||
(
|
||||
id int AUTO_INCREMENT,
|
||||
name varchar(100) null,
|
||||
description varchar(500) null,
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
created_by varchar(100) null,
|
||||
created_at TIMESTAMP null,
|
||||
updated_by varchar(100) null,
|
||||
updated_at TIMESTAMP null,
|
||||
enable_search int null,
|
||||
PRIMARY KEY (`id`)
|
||||
); COMMENT ON TABLE s2_agent IS 'assistant information table';
|
||||
|
||||
|
||||
-------demo for semantic and chat
|
||||
CREATE TABLE IF NOT EXISTS `s2_user_department`
|
||||
|
||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
@@ -22,6 +23,7 @@ import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
@@ -42,6 +44,8 @@ public class BaseQueryTest {
|
||||
protected ChatService chatService;
|
||||
@Autowired
|
||||
protected ConfigService configService;
|
||||
@MockBean
|
||||
protected AgentService agentService;
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText) throws Exception {
|
||||
ParseResp parseResp = submitParse(queryText);
|
||||
@@ -78,6 +82,11 @@ public class BaseQueryTest {
|
||||
return queryService.performParsing(queryContextReq);
|
||||
}
|
||||
|
||||
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
|
||||
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(10, queryText, agentId);
|
||||
return queryService.performParsing(queryContextReq);
|
||||
}
|
||||
|
||||
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
|
||||
Set<String> expectedNames = expected.stream().map(s -> s.getName())
|
||||
.filter(s -> s != null).collect(Collectors.toSet());
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest(classes = StandaloneLauncher.class)
|
||||
@ActiveProfiles("local")
|
||||
public class MetricInterpretTest {
|
||||
|
||||
@MockBean
|
||||
private AgentService agentService;
|
||||
|
||||
@MockBean
|
||||
private PluginManager pluginManager;
|
||||
|
||||
@MockBean
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Autowired
|
||||
@Qualifier("chatQueryService")
|
||||
private QueryService queryService;
|
||||
|
||||
@Test
|
||||
public void testMetricInterpret() throws Exception {
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
|
||||
lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多");
|
||||
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
|
||||
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
||||
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
|
||||
DataUtils.getAgent().getId());
|
||||
QueryResult queryResult = queryService.executeQuery(queryReq);
|
||||
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,30 +1,31 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricFilterQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricGroupByQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricTopNQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
|
||||
|
||||
|
||||
public class MetricQueryTest extends BaseQueryTest {
|
||||
|
||||
@@ -50,6 +51,17 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_FILTER_with_agent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
|
||||
Assert.assertNotNull(parseResp.getSelectedParses());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_DOMAIN() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数的访问次数");
|
||||
@@ -69,6 +81,16 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_MODEL_with_agent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_METRIC_GROUPBY() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数各部门的访问次数");
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
package com.tencent.supersonic.integration.plugin;
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.notNull;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.notNull;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class PluginMockConfiguration {
|
||||
public class MockConfiguration {
|
||||
|
||||
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
||||
EmbeddingResp embeddingResp = new EmbeddingResp();
|
||||
@@ -33,9 +34,12 @@ public class PluginMockConfiguration {
|
||||
when(embeddingConfig.getUrl()).thenReturn("test");
|
||||
}
|
||||
|
||||
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
|
||||
ResponseEntity<String> responseEntity) {
|
||||
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity<String> responseEntity) {
|
||||
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
|
||||
}
|
||||
|
||||
public static void mockAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
||||
}
|
||||
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user