[improvement](supersonic) based on version 0.7.2 (#34)

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2023-08-20 17:30:35 +08:00
committed by GitHub
parent c93e60ced7
commit cf1b5336c3
122 changed files with 4045 additions and 1075 deletions

View File

@@ -10,8 +10,6 @@ import lombok.ToString;
@ToString
public class QueryAuthResReq {
private String user;
private List<String> departmentIds = new ArrayList<>();
private List<AuthRes> resources;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.auth.api.authorization.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
@@ -14,5 +15,5 @@ public interface AuthService {
void removeAuthGroup(AuthGroup group);
AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, HttpServletRequest request);
AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user);
}

View File

@@ -1,21 +0,0 @@
package com.tencent.supersonic.auth.authentication.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Data
@Configuration
public class TppConfig {
@Value(value = "${auth.app.secret:}")
private String appSecret;
@Value(value = "${auth.app.key:}")
private String appKey;
@Value(value = "${auth.oa.url:}")
private String tppOaUrl;
}

View File

@@ -2,27 +2,24 @@ package com.tencent.supersonic.auth.authorization.application;
import com.google.common.base.Strings;
import com.google.gson.Gson;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter;
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.common.util.S2ThreadContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@Service
@Slf4j
@@ -78,12 +75,12 @@ public class AuthServiceImpl implements AuthService {
@Override
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, HttpServletRequest request) {
Set<String> userOrgIds = userService.getUserAllOrgId(req.getUser());
public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) {
Set<String> userOrgIds = userService.getUserAllOrgId(user.getName());
if (!CollectionUtils.isEmpty(userOrgIds)) {
req.setDepartmentIds(new ArrayList<>(userOrgIds));
}
List<AuthGroup> groups = getAuthGroups(req);
List<AuthGroup> groups = getAuthGroups(req, user.getName());
AuthorizedResourceResp resource = new AuthorizedResourceResp();
Map<String, List<AuthGroup>> authGroupsByModelId = groups.stream()
.collect(Collectors.groupingBy(AuthGroup::getModelId));
@@ -130,14 +127,14 @@ public class AuthServiceImpl implements AuthService {
return resource;
}
private List<AuthGroup> getAuthGroups(QueryAuthResReq req) {
private List<AuthGroup> getAuthGroups(QueryAuthResReq req, String userName) {
List<AuthGroup> groups = load().stream()
.filter(group -> {
if (!Objects.equals(group.getModelId(), req.getModelId())) {
return false;
}
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers()
.contains(req.getUser())) {
.contains(userName)) {
return true;
}
for (String departmentId : req.getDepartmentIds()) {
@@ -148,7 +145,7 @@ public class AuthServiceImpl implements AuthService {
}
return false;
}).collect(Collectors.toList());
log.info("user:{} department:{} authGroups:{}", req.getUser(), req.getDepartmentIds(), groups);
log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups);
return groups;
}

View File

@@ -1,11 +1,14 @@
package com.tencent.supersonic.auth.authorization.rest;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq;
import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp;
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
@@ -62,12 +65,13 @@ public class AuthController {
* 查询有权限访问的受限资源id
*
* @param req
* @param request
* @return
*/
@PostMapping("/queryAuthorizedRes")
public AuthorizedResourceResp queryAuthorizedResources(@RequestBody QueryAuthResReq req,
HttpServletRequest request) {
return authService.queryAuthorizedResources(req, request);
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return authService.queryAuthorizedResources(req, user);
}
}

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.chat.api.component;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import net.sf.jsqlparser.JSQLParserException;
public interface DSLOptimizer {
CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException;
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class CorrectionInfo {
private QueryFilters queryFilters;
private SemanticParseInfo parseInfo;
private String sql;
}

View File

@@ -12,4 +12,5 @@ public class QueryReq {
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private Integer agentId;
}

View File

@@ -40,12 +40,6 @@
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.github.plexpt</groupId>
<artifactId>chatgpt</artifactId>
<version>4.1.2</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Data
public class Agent extends RecordInfo {
private Integer id;
private Integer enableSearch;
private String name;
private String description;
//0 offline, 1 online
private Integer status;
private List<String> examples;
private String agentConfig;
public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList();
}
List<Map> toolList = (List) map.get("tools");
return toolList.stream()
.filter(tool -> type.name().equals(tool.get("type")))
.map(JSONObject::toJSONString)
.collect(Collectors.toList());
}
public boolean enableSearch() {
return enableSearch != null && enableSearch == 1;
}
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.agent;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.tool.AgentTool;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class AgentConfig {
List<AgentTool> tools = Lists.newArrayList();
}

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class AgentTool {
private String name;
private AgentToolType type;
}

View File

@@ -0,0 +1,8 @@
package com.tencent.supersonic.chat.agent.tool;
public enum AgentToolType {
RULE,
DSL,
PLUGIN,
INTERPRET
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.Data;
import java.util.List;
@Data
public class DslTool extends AgentTool {
private List<Long> modelIds;
private List<String> exampleQuestions;
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.agent.tool;
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
import lombok.Data;
import java.util.List;
@Data
public class MetricInterpretTool extends AgentTool {
private Long modelId;
private List<MetricOption> metricOptions;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.Data;
import java.util.List;
@Data
public class PluginTool extends AgentTool {
private List<Long> plugins;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.agent.tool;
import lombok.Data;
import java.util.List;
@Data
public class RuleQueryTool extends AgentTool {
private List<String> queryModes;
}

View File

@@ -19,7 +19,7 @@ import org.springframework.stereotype.Service;
@Slf4j
public class MapperHelper {
@Value("${one.detection.size:6}")
@Value("${one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@@ -64,7 +64,7 @@ public class MapperHelper {
*/
public boolean existDimensionValues(List<String> natures) {
for (String nature : natures) {
if (NatureHelper.isDimensionValueClassId(nature)) {
if (NatureHelper.isDimensionValueModelId(nature)) {
return true;
}
}

View File

@@ -32,7 +32,7 @@ public class QueryMatchStrategy implements MatchStrategy {
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectmodelId) {
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId) {
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
@@ -43,22 +43,18 @@ public class QueryMatchStrategy implements MatchStrategy {
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectmodelId:{}", terms,
regOffsetToLength, offsetList, detectmodelId);
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelId:{}", terms,
regOffsetToLength, offsetList, detectModelId);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectmodelId);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectModelId);
Map<MatchText, List<MapResult>> result = new HashMap<>();
MatchText matchText = MatchText.builder()
.regText(text)
.detectSegment(text)
.build();
result.put(matchText, detects);
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Long detectmodelId) {
Long detectModelId) {
List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) {
@@ -69,7 +65,7 @@ public class QueryMatchStrategy implements MatchStrategy {
int offset = mapperHelper.getStepOffset(offsetList, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(text, detectmodelId, index, i, offset);
List<MapResult> mapResults = detectByStep(text, detectModelId, index, i, offset);
selectMapResultInOneRound(mapResultRowSet, mapResults);
}
}
@@ -106,15 +102,15 @@ public class QueryMatchStrategy implements MatchStrategy {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
private List<MapResult> detectByStep(String text, Long detectmodelId, Integer index, Integer i, int offset) {
private List<MapResult> detectByStep(String text, Long detectModelId, Integer index, Integer i, int offset) {
String detectSegment = text.substring(index, i);
Integer oneDetectionSize = mapperHelper.getOneDetectionSize();
// step1. pre search
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
mapperHelper.getOneDetectionMaxSize())
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize();
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionSize)
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
mapResults.addAll(suffixMapResults);
@@ -126,11 +122,11 @@ public class QueryMatchStrategy implements MatchStrategy {
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by classId
if (Objects.nonNull(detectmodelId) && detectmodelId > 0) {
log.debug("detectmodelId:{}, before parseResults:{}", mapResults);
if (Objects.nonNull(detectModelId) && detectModelId > 0) {
log.debug("detectModelId:{}, before parseResults:{}", mapResults);
mapResults = mapResults.stream().map(entry -> {
List<String> natures = entry.getNatures().stream().filter(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectmodelId) || (nature.startsWith(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectModelId) || (nature.startsWith(
DictWordType.NATURE_SPILT))
).collect(Collectors.toList());
entry.setNatures(natures);
@@ -145,8 +141,7 @@ public class QueryMatchStrategy implements MatchStrategy {
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
.collect(Collectors.toCollection(LinkedHashSet::new));
log.debug("metricDimensionThreshold:{},dimensionValueThreshold:{},after isSimilarity parseResults:{}",
mapResults);
log.debug("after isSimilarity parseResults:{}", mapResults);
mapResults = mapResults.stream().map(parseResult -> {
parseResult.setOffset(offset);
@@ -165,7 +160,7 @@ public class QueryMatchStrategy implements MatchStrategy {
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
return dimensionMetrics;
} else {
return mapResults.stream().limit(oneDetectionSize).collect(Collectors.toList());
return mapResults.stream().limit(mapperHelper.getOneDetectionSize()).collect(Collectors.toList());
}
}
}

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import lombok.extern.slf4j.Slf4j;
/**
@@ -21,6 +21,9 @@ public class SatisfactionChecker {
// check all the parse info in candidate
public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (query.getQueryMode().equals(DSLQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
return true;
}

View File

@@ -2,14 +2,7 @@ package com.tencent.supersonic.chat.parser.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.ParseMode;
@@ -17,17 +10,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -53,48 +43,32 @@ public class EmbeddingBasedParser implements SemanticParser {
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return;
}
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> plugins = pluginService.getPluginList();
List<Plugin> plugins = getPluginList(queryContext);
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
if (plugin == null || DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
continue;
}
Pair<Boolean, List<Long>> pair = PluginManager.resolve(plugin, queryContext);
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
List<Long> modelList = pair.getRight();
Set<Long> modelList = pair.getRight();
if (CollectionUtils.isEmpty(modelList)) {
return;
}
modelList = distinctModelList(plugin, queryContext.getMapInfo(), modelList);
for (Long modelId : modelList) {
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext,
queryContext.getMapInfo().getMatchedElements(modelId));
if (plugin.isContainsAllModel()) {
break;
}
}
return;
}
}
}
public List<Long> distinctModelList(Plugin plugin, SchemaMapInfo schemaMapInfo, List<Long> modelList) {
if (!plugin.isContainsAllModel()) {
return modelList;
}
boolean noElementMatch = true;
for (Long model : modelList) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(model);
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
noElementMatch = false;
}
}
if (noElementMatch) {
return modelList.subList(0, 1);
}
return modelList;
}
private void buildQuery(Plugin plugin, double distance, Long modelId,
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
@@ -126,6 +100,8 @@ public class EmbeddingBasedParser implements SemanticParser {
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
@@ -176,4 +152,8 @@ public class EmbeddingBasedParser implements SemanticParser {
}
}
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
}
}

View File

@@ -14,23 +14,16 @@ import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -62,6 +55,10 @@ public class FunctionBasedParser implements SemanticParser {
return;
}
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx);
if (CollectionUtils.isEmpty(functionDOList)) {
log.info("function call parser, plugin is empty, skip");
return;
}
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
@@ -85,7 +82,7 @@ public class FunctionBasedParser implements SemanticParser {
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
ModelResolver ModelResolver = ComponentFactory.getModelResolver();
log.info("plugin ModelList:{}", plugin.getModelList());
Pair<Boolean, List<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
log.info("FunctionBasedParser modelId:{}", modelId);
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
@@ -102,6 +99,8 @@ public class FunctionBasedParser implements SemanticParser {
functionCallParseResult.setRequest(queryCtx.getRequest());
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, functionCallParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
@@ -112,17 +111,6 @@ public class FunctionBasedParser implements SemanticParser {
queryCtx.getCandidateQueries().add(semanticQuery);
}
private Set<Long> getMatchModels(QueryContext queryCtx) {
Set<Long> result = new HashSet<>();
Long modelId = queryCtx.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
result.add(modelId);
return result;
}
return queryCtx.getMapInfo().getMatchedModels();
}
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) {
return true;
@@ -140,7 +128,7 @@ public class FunctionBasedParser implements SemanticParser {
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = PluginManager.getPlugins();
List<Plugin> plugins = getPluginList(queryContext);
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
@@ -153,12 +141,12 @@ public class FunctionBasedParser implements SemanticParser {
if (StringUtils.isBlank(pluginParseConfig.getName())) {
return false;
}
Pair<Boolean, List<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
Pair<Boolean, Set<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
log.info("plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
if (!pluginResolverResult.getLeft()) {
return false;
} else {
List<Long> resolveModel = pluginResolverResult.getRight();
Set<Long> resolveModel = pluginResolverResult.getRight();
if (modelId != null && modelId > 0) {
if (plugin.isContainsAllModel()) {
return true;
@@ -172,20 +160,6 @@ public class FunctionBasedParser implements SemanticParser {
return functionDOList;
}
private List<String> getFunctionNames(Set<Long> matchedModels) {
List<Plugin> plugins = PluginManager.getPlugins();
Set<String> functionNames = plugins.stream()
.filter(entry -> {
if (!CollectionUtils.isEmpty(entry.getModelList()) && !CollectionUtils.isEmpty(matchedModels)) {
return entry.getModelList().stream().anyMatch(matchedModels::contains);
}
return true;
}
).map(Plugin::getName).collect(Collectors.toSet());
functionNames.add(DSLQuery.QUERY_MODE);
return new ArrayList<>(functionNames);
}
public FunctionResp requestFunction(String url, FunctionReq functionReq) {
HttpHeaders headers = new HttpHeaders();
long startTime = System.currentTimeMillis();
@@ -205,4 +179,8 @@ public class FunctionBasedParser implements SemanticParser {
}
return null;
}
protected List<Plugin> getPluginList(QueryContext queryContext) {
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
}
}

View File

@@ -1,21 +1,12 @@
package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
@@ -55,7 +46,7 @@ public class HeuristicModelResolver implements ModelResolver {
* @return false will use context Model, true will use other Model , maybe include context Model
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long modelId, List<Long> restrictiveModels) {
ChatContext chatCtx, QueryReq searchCtx, Long modelId, Set<Long> restrictiveModels) {
if (!Objects.nonNull(modelId) || modelId <= 0) {
return true;
}
@@ -80,8 +71,7 @@ public class HeuristicModelResolver implements ModelResolver {
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord()
.equalsIgnoreCase(searchCtx.getQueryText())) {
if (semanticParseInfo.getDateInfo().getDetectWord().equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
@@ -131,7 +121,7 @@ public class HeuristicModelResolver implements ModelResolver {
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels) {
public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
Long modelId = queryContext.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
if (CollectionUtils.isNotEmpty(restrictiveModels) && restrictiveModels.contains(modelId)) {
@@ -151,17 +141,16 @@ public class HeuristicModelResolver implements ModelResolver {
for (Long matchedModel : matchedModels) {
ModelQueryModes.put(matchedModel, null);
}
if (ModelQueryModes.size() == 1) {
if(ModelQueryModes.size()==1){
return ModelQueryModes.keySet().stream().findFirst().get();
}
return resolve(ModelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(), restrictiveModels);
queryContext.getMapInfo(),restrictiveModels);
}
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,
restrictiveModels);
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveModels);
if (selectModel > 0) {
log.info("selectModel {} ", selectModel);
return selectModel;
@@ -172,7 +161,7 @@ public class HeuristicModelResolver implements ModelResolver {
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
// if QueryContext has modelId and in ModelQueryModes
if (ModelQueryModes.containsKey(queryContext.getModelId())) {
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
@@ -181,7 +170,7 @@ public class HeuristicModelResolver implements ModelResolver {
// if ChatContext has modelId and in ModelQueryModes
if (chatCtx.getParseInfo().getModelId() > 0) {
Long modelId = chatCtx.getParseInfo().getModelId();
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId,restrictiveModels)) {
log.info("selectModel from ChatContext [{}]", modelId);
return modelId;
}

View File

@@ -4,9 +4,10 @@ package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
import java.util.Set;
public interface ModelResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels);
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
}

View File

@@ -1,11 +0,0 @@
package com.tencent.supersonic.chat.parser.llm;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import lombok.Data;
@Data
public class DSLParseResult extends PluginParseResult {
private LLMResp llmResp;
}

View File

@@ -0,0 +1,20 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
import com.tencent.supersonic.common.util.DateUtils;
public class DSLDateHelper {
public static String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4);
// ChatConfigFilter filter = new ChatConfigFilter();
// filter.setModelId(modelId);
//
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
// if (CollectionUtils.isEmpty(configResps)) {
// return
// }
// ChatConfigResp chatConfigResp = configResps.get(0);
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.parser.llm.dsl;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import lombok.Data;
@Data
public class DSLParseResult {
private LLMResp llmResp;
private QueryReq request;
private DslTool dslTool;
}

View File

@@ -1,5 +1,10 @@
package com.tencent.supersonic.chat.parser.llm;
package com.tencent.supersonic.chat.parser.llm.dsl;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -11,26 +16,26 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.function.ModelResolver;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLBuilder;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.dsl.LLMReq;
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import com.tencent.supersonic.chat.query.dsl.optimizer.BaseDSLOptimizer;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@@ -56,30 +61,30 @@ public class LLMDSLParser implements SemanticParser {
queryCtx.getRequest().getQueryText());
return;
}
List<Plugin> dslPlugins = PluginManager.getPlugins().stream()
.filter(plugin -> DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dslPlugins)) {
return;
}
Plugin plugin = dslPlugins.get(0);
List<Long> dslModels = plugin.getModelList();
List<DslTool> dslTools = getDslTools(queryCtx.getRequest().getAgentId());
Set<Long> distinctModelIds = dslTools.stream().map(DslTool::getModelIds)
.flatMap(Collection::stream)
.collect(Collectors.toSet());
try {
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, dslModels);
log.info("resolve modelId:{},dslModels:{}", modelId, dslModels);
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
if (Objects.isNull(modelId)) {
if (Objects.isNull(modelId) || modelId <= 0) {
return;
}
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool ->
tool.getModelIds().contains(modelId)).findFirst();
if (!dslToolOptional.isPresent()) {
log.info("no dsl tool in this agent, skip dsl parser");
return;
}
DslTool dslTool = dslToolOptional.get();
LLMResp llmResp = requestLLM(queryCtx, modelId);
if (Objects.isNull(llmResp)) {
return;
}
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
@@ -89,10 +94,12 @@ public class LLMDSLParser implements SemanticParser {
DSLParseResult dslParseResult = new DSLParseResult();
dslParseResult.setRequest(queryCtx.getRequest());
dslParseResult.setLlmResp(llmResp);
dslParseResult.setPlugin(plugin);
dslParseResult.setDslTool(dslToolOptional.get());
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
properties.put("type", "internal");
properties.put("name", dslTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
@@ -126,13 +133,13 @@ public class LLMDSLParser implements SemanticParser {
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
fieldNameList.add(DSLBuilder.DATA_Field);
fieldNameList.add(BaseDSLOptimizer.DATE_FIELD);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = getCurrentDate(modelId);
String currentDate = DSLDateHelper.getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
@@ -156,21 +163,6 @@ public class LLMDSLParser implements SemanticParser {
return null;
}
private String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4);
// ChatConfigFilter filter = new ChatConfigFilter();
// filter.setModelId(modelId);
//
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
// if (CollectionUtils.isEmpty(configResps)) {
// return
// }
// ChatConfigResp chatConfigResp = configResps.get(0);
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
}
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
@@ -228,4 +220,18 @@ public class LLMDSLParser implements SemanticParser {
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private List<DslTool> getDslTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.DSL);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class))
.collect(Collectors.toList());
}
}

View File

@@ -0,0 +1,144 @@
package com.tencent.supersonic.chat.parser.llm.interpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.metricInterpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
public class MetricInterpretParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
if (SatisfactionChecker.check(queryContext)) {
log.info("skip MetricInterpretParser");
return;
}
Map<Long, MetricInterpretTool> metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId());
log.info("metric interpret tool : {}", metricInterpretToolMap);
if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
return;
}
Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches();
for (Long modelId : elementMatches.keySet()) {
MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId);
if (metricInterpretTool == null) {
continue;
}
if (CollectionUtils.isEmpty(elementMatches.get(modelId))) {
continue;
}
List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions();
if (!CollectionUtils.isEmpty(metricOptions)) {
List<Long> metricIds = metricOptions.stream().map(MetricOption::getMetricId).collect(Collectors.toList());
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), metricInterpretTool.getName());
}
}
}
private void buildQuery(Long modelId, QueryContext queryContext,
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
metrics, schemaElementMatches, toolName);
semanticParseInfo.setQueryMode(metricInterpretQuery.getQueryMode());
semanticParseInfo.getProperties().put("queryText", queryContext.getRequest().getQueryText());
metricInterpretQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(metricInterpretQuery);
}
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
ModelSchema modelSchema = semanticLayer.getModelSchema(modelId, true);
Set<SchemaElement> metrics = modelSchema.getMetrics();
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
.collect(Collectors.toSet());
}
private Map<Long, MetricInterpretTool> getMetricInterpretTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return new HashMap<>();
}
List<String> tools= agent.getTools(AgentToolType.INTERPRET);
if (CollectionUtils.isEmpty(tools)) {
return new HashMap<>();
}
List<MetricInterpretTool> metricInterpretTools = tools.stream().map(tool ->
JSONObject.parseObject(tool, MetricInterpretTool.class))
.filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions()))
.collect(Collectors.toList());
Map<Long, MetricInterpretTool> metricInterpretToolMap = new HashMap<>();
for (MetricInterpretTool metricInterpretTool : metricInterpretTools) {
metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(),
metricInterpretTool);
}
return metricInterpretToolMap;
}
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
List<SchemaElementMatch> schemaElementMatches, String toolName) {
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setMetrics(metrics);
SchemaElement dimension = new SchemaElement();
dimension.setBizName(TimeDimensionEnum.DAY.getName());
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model);
semanticParseInfo.setScore(queryReq.getQueryText().length());
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);
dateConf.setUnit(15);
semanticParseInfo.setDateInfo(dateConf);
Map<String, Object> properties = new HashMap<>();
properties.put("type", "internal");
properties.put("name", toolName);
semanticParseInfo.setProperties(properties);
fillSemanticParseInfo(semanticParseInfo);
return semanticParseInfo;
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.parser.llm.interpret;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class MetricOption {
private Long metricId;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.parser.llm;
package com.tencent.supersonic.chat.parser.llm.time;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.utils.ChatGptHelper;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ChatGptHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
@@ -17,20 +17,20 @@ public class LLMTimeEnhancementParse implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
try {
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
if (!queryContext.getCandidateQueries().isEmpty()) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
DateConf dateInfo = query.getParseInfo().getDateInfo();
JSONObject jsonObject = JSON.parseObject(inferredTime);
if (jsonObject.containsKey("date")) {
if (jsonObject.containsKey("date")){
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("date"));
dateInfo.setEndDate(jsonObject.getString("date"));
query.getParseInfo().setDateInfo(dateInfo);
} else if (jsonObject.containsKey("start")) {
}else if (jsonObject.containsKey("start")){
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("start"));
dateInfo.setEndDate(jsonObject.getString("end"));
@@ -38,12 +38,11 @@ public class LLMTimeEnhancementParse implements SemanticParser {
}
}
}
} catch (Exception exception) {
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
(Object) exception.getStackTrace());
}catch (Exception exception){
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
}
log.info("after queryContext:{},chatContext:{}", queryContext, chatContext);
log.info("{} after queryContext:{},chatContext:{}",LLMTimeEnhancementParse.class.getSimpleName(),queryContext,chatContext);
}

View File

@@ -0,0 +1,63 @@
package com.tencent.supersonic.chat.parser.rule;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public class AgentCheckParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> queries = queryContext.getCandidateQueries();
agentCanSupport(queryContext.getRequest().getAgentId(), queries);
}
private void agentCanSupport(Integer agentId, List<SemanticQuery> queries) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return;
}
List<String> queryModes = getRuleTools(agentId).stream().map(RuleQueryTool::getQueryModes)
.flatMap(Collection::stream).collect(Collectors.toList());
if (CollectionUtils.isEmpty(queries)) {
queries.clear();
return;
}
log.info("queries resolved:{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
queries.removeIf(query ->
!queryModes.contains(query.getQueryMode()));
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
}
private static List<RuleQueryTool> getRuleTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.RULE);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class))
.collect(Collectors.toList());
}
}

View File

@@ -1,12 +1,9 @@
package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.*;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -15,12 +12,10 @@ public class QueryModeParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
// iterate all schemaElementMatches to resolve semantic query
for (Long modelId : mapInfo.getMatchedModels()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(modelId, queryContext, chatContext);
queryContext.getCandidateQueries().add(query);

View File

@@ -0,0 +1,236 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
public class AgentDO {
/**
*
*/
private Integer id;
/**
*
*/
private String name;
/**
*
*/
private String description;
/**
* 0 offline, 1 online
*/
private Integer status;
/**
*
*/
private String examples;
/**
*
*/
private String config;
/**
*
*/
private String createdBy;
/**
*
*/
private Date createdAt;
/**
*
*/
private String updatedBy;
/**
*
*/
private Date updatedAt;
/**
*
*/
private Integer enableSearch;
/**
*
* @return id
*/
public Integer getId() {
return id;
}
/**
*
* @param id
*/
public void setId(Integer id) {
this.id = id;
}
/**
*
* @return name
*/
public String getName() {
return name;
}
/**
*
* @param name
*/
public void setName(String name) {
this.name = name == null ? null : name.trim();
}
/**
*
* @return description
*/
public String getDescription() {
return description;
}
/**
*
* @param description
*/
public void setDescription(String description) {
this.description = description == null ? null : description.trim();
}
/**
* 0 offline, 1 online
* @return status 0 offline, 1 online
*/
public Integer getStatus() {
return status;
}
/**
* 0 offline, 1 online
* @param status 0 offline, 1 online
*/
public void setStatus(Integer status) {
this.status = status;
}
/**
*
* @return examples
*/
public String getExamples() {
return examples;
}
/**
*
* @param examples
*/
public void setExamples(String examples) {
this.examples = examples == null ? null : examples.trim();
}
/**
*
* @return config
*/
public String getConfig() {
return config;
}
/**
*
* @param config
*/
public void setConfig(String config) {
this.config = config == null ? null : config.trim();
}
/**
*
* @return created_by
*/
public String getCreatedBy() {
return createdBy;
}
/**
*
* @param createdBy
*/
public void setCreatedBy(String createdBy) {
this.createdBy = createdBy == null ? null : createdBy.trim();
}
/**
*
* @return created_at
*/
public Date getCreatedAt() {
return createdAt;
}
/**
*
* @param createdAt
*/
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}
/**
*
* @return updated_by
*/
public String getUpdatedBy() {
return updatedBy;
}
/**
*
* @param updatedBy
*/
public void setUpdatedBy(String updatedBy) {
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
}
/**
*
* @return updated_at
*/
public Date getUpdatedAt() {
return updatedAt;
}
/**
*
* @param updatedAt
*/
public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt;
}
/**
*
* @return enable_search
*/
public Integer getEnableSearch() {
return enableSearch;
}
/**
*
* @param enableSearch
*/
public void setEnableSearch(Integer enableSearch) {
this.enableSearch = enableSearch;
}
}

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface AgentDOMapper {
/**
*
* @mbg.generated
*/
long countByExample(AgentDOExample example);
/**
*
* @mbg.generated
*/
int deleteByPrimaryKey(Integer id);
/**
*
* @mbg.generated
*/
int insert(AgentDO record);
/**
*
* @mbg.generated
*/
int insertSelective(AgentDO record);
/**
*
* @mbg.generated
*/
List<AgentDO> selectByExample(AgentDOExample example);
/**
*
* @mbg.generated
*/
AgentDO selectByPrimaryKey(Integer id);
/**
*
* @mbg.generated
*/
int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
/**
*
* @mbg.generated
*/
int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
/**
*
* @mbg.generated
*/
int updateByPrimaryKeySelective(AgentDO record);
/**
*
* @mbg.generated
*/
int updateByPrimaryKey(AgentDO record);
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import java.util.List;
public interface AgentRepository {
List<AgentDO> getAgents();
void createAgent(AgentDO agentDO);
void updateAgent(AgentDO agentDO);
AgentDO getAgent(Integer id);
void deleteAgent(Integer id);
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample;
import com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
import org.springframework.stereotype.Repository;
import java.util.List;
@Repository
public class AgentRepositoryImpl implements AgentRepository {
private AgentDOMapper agentDOMapper;
public AgentRepositoryImpl(AgentDOMapper agentDOMapper) {
this.agentDOMapper = agentDOMapper;
}
@Override
public List<AgentDO> getAgents() {
return agentDOMapper.selectByExample(new AgentDOExample());
}
@Override
public void createAgent(AgentDO agentDO) {
agentDOMapper.insert(agentDO);
}
@Override
public void updateAgent(AgentDO agentDO) {
agentDOMapper.updateByPrimaryKey(agentDO);
}
@Override
public AgentDO getAgent(Integer id) {
return agentDOMapper.selectByPrimaryKey(id);
}
@Override
public void deleteAgent(Integer id) {
agentDOMapper.deleteByPrimaryKey(id);
}
}

View File

@@ -7,12 +7,14 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
@Repository
@Primary
@Slf4j
public class ChatContextRepositoryImpl implements ChatContextRepository {
@Autowired(required = false)
@@ -50,8 +52,8 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
chatContext.setUser(contextDO.getUser());
chatContext.setQueryText(contextDO.getQueryText());
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
SemanticParseInfo.class);
log.info("--->: {}",contextDO.getSemanticParse());
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class);
chatContext.setParseInfo(semanticParseInfo);
}
return chatContext;

View File

@@ -3,11 +3,11 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.agent.tool.DslTool;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
@@ -16,30 +16,22 @@ import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.*;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@@ -59,12 +51,40 @@ public class PluginManager {
this.restTemplate = restTemplate;
}
public static List<Plugin> getPlugins() {
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> pluginList = pluginService.getPluginList().stream().filter(plugin ->
CollectionUtils.isNotEmpty(plugin.getModelList())).collect(Collectors.toList());
pluginList.addAll(internalPluginMap.values());
return new ArrayList<>(pluginList);
List<Plugin> plugins = pluginService.getPluginList();
if (agentId == null) {
return plugins;
}
Agent agent = ContextUtils.getBean(AgentService.class).getAgent(agentId);
if (agent == null) {
return plugins;
}
List<Long> pluginIds = getPluginTools(agentId).stream().map(PluginTool::getPlugins)
.flatMap(Collection::stream).collect(Collectors.toList());
if (CollectionUtils.isEmpty(pluginIds)) {
return Lists.newArrayList();
}
plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId()))
.collect(Collectors.toList());
log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
plugins.stream().map(Plugin::getName).collect(Collectors.toList()));
return plugins;
}
private static List<PluginTool> getPluginTools(Integer agentId) {
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(agentId);
if (agent == null) {
return Lists.newArrayList();
}
List<String> tools = agent.getTools(AgentToolType.PLUGIN);
if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList();
}
return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class))
.collect(Collectors.toList());
}
@EventListener
@@ -201,17 +221,17 @@ public class PluginManager {
return String.valueOf(Integer.parseInt(id) / 1000);
}
public static Pair<Boolean, List<Long>> resolve(Plugin plugin, QueryContext queryContext) {
public static Pair<Boolean, Set<Long>> resolve(Plugin plugin, QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
return Pair.of(false, Lists.newArrayList());
return Pair.of(false, Sets.newHashSet());
}
List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, new ArrayList<>(pluginMatchedModel));
return Pair.of(true, Sets.newHashSet());
}
List<Long> matchedModel = Lists.newArrayList();
Set<Long> matchedModel = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream().
collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) {
@@ -237,7 +257,7 @@ public class PluginManager {
}
}
if (CollectionUtils.isEmpty(matchedModel)) {
return Pair.of(false, Lists.newArrayList());
return Pair.of(false, Sets.newHashSet());
}
return Pair.of(true, matchedModel);
}

View File

@@ -1,149 +0,0 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.beans.BeanUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class ContentInterpretQuery extends PluginSemanticQuery {
@Override
public String getQueryMode() {
return "CONTENT_INTERPRET";
}
public ContentInterpretQuery() {
QueryManager.register(this);
}
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryResultWithSchemaResp queryResultWithSchemaResp = queryMetric(user);
String text = generateDataText(queryResultWithSchemaResp);
Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT))
, PluginParseResult.class);
String answer = fetchInterpret(pluginParseResult.getRequest().getQueryText(), text);
QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
Map<String, Object> result = new HashMap<>();
result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList();
resultList.add(result);
queryResultWithSchemaResp.setResultList(resultList);
queryResultWithSchemaResp.setColumns(queryColumns);
queryResult.setResponse(queryResultWithSchemaResp);
queryResult.setQueryMode(getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private QueryResultWithSchemaResp queryMetric(User user) {
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setModelId(parseInfo.getModelId());
queryStructReq.setGroups(Lists.newArrayList(TimeDimensionEnum.DAY.getName()));
ModelSchema modelSchema = semanticLayer.getModelSchema(parseInfo.getModelId(), true);
queryStructReq.setAggregators(buildAggregator(modelSchema));
List<Filter> filterList = Lists.newArrayList();
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
Filter filter = new Filter();
BeanUtils.copyProperties(queryFilter, filter);
filterList.add(filter);
}
queryStructReq.setDimensionFilters(filterList);
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);
dateConf.setUnit(7);
queryStructReq.setDateInfo(dateConf);
return semanticLayer.queryByStruct(queryStructReq, user);
}
private List<Aggregator> buildAggregator(ModelSchema modelSchema) {
List<Aggregator> aggregators = Lists.newArrayList();
Set<SchemaElement> metrics = modelSchema.getMetrics();
if (CollectionUtils.isEmpty(metrics)) {
return aggregators;
}
for (SchemaElement schemaElement : metrics) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(schemaElement.getBizName());
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setNameCh(schemaElement.getName());
aggregators.add(aggregator);
}
return aggregators;
}
public String generateDataText(QueryResultWithSchemaResp queryResultWithSchemaResp) {
Map<String, String> map = queryResultWithSchemaResp.getColumns().stream()
.collect(Collectors.toMap(QueryColumn::getNameEn, QueryColumn::getName));
StringBuilder stringBuilder = new StringBuilder();
for (Map<String, Object> valueMap : queryResultWithSchemaResp.getResultList()) {
for (String key : valueMap.keySet()) {
String name = "";
if (TimeDimensionEnum.getNameList().contains(key)) {
name = "日期";
} else {
name = map.get(key);
}
String value = String.valueOf(valueMap.get(key));
stringBuilder.append(name).append(":").append(value).append(" ");
}
}
return stringBuilder.toString();
}
public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistant_message();
}
return null;
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
@@ -18,7 +19,7 @@ public class HeuristicQuerySelector implements QuerySelector {
private static final double CANDIDATE_THRESHOLD = 0.2;
@Override
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries) {
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
List<SemanticQuery> selectedQueries = new ArrayList<>();
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.List;
/**
@@ -8,5 +10,5 @@ import java.util.List;
**/
public interface QuerySelector {
List<SemanticQuery> select(List<SemanticQuery> candidateQueries);
List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq);
}

View File

@@ -1,78 +0,0 @@
package com.tencent.supersonic.chat.query.dsl;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class DSLBuilder {
public static final String DATA_Field = "数据日期";
public static final String TABLE_PREFIX = "t_";
public String build(SemanticParseInfo parseInfo, QueryFilters queryFilters, LLMResp llmResp, Long modelId)
throws Exception {
String sqlOutput = llmResp.getSqlOutput();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
dbAllFields.addAll(semanticSchema.getDimensions());
Map<String, String> fieldToBizName = getMapInfo(modelId, dbAllFields);
fieldToBizName.put(DATA_Field, TimeDimensionEnum.DAY.getName());
sqlOutput = CCJSqlParserUtils.replaceFields(sqlOutput, fieldToBizName);
sqlOutput = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
String queryFilter = getQueryFilter(queryFilters);
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
CCJSqlParserUtils.addWhere(sqlOutput, expression);
}
log.info("build sqlOutput:{}", sqlOutput);
return sqlOutput;
}
protected Map<String, String> getMapInfo(Long modelId, List<SchemaElement> metrics) {
return metrics.stream().filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return "";
}
List<QueryFilter> filters = queryFilters.getFilters();
return filters.stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -2,14 +2,14 @@ package com.tencent.supersonic.chat.query.dsl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.DSLParseResult;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
@@ -32,7 +32,6 @@ import org.springframework.stereotype.Component;
public class DSLQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL";
private DSLBuilder dslBuilder = new DSLBuilder();
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DSLQuery() {
@@ -51,12 +50,26 @@ public class DSLQuery extends PluginSemanticQuery {
LLMResp llmResp = dslParseResult.getLlmResp();
QueryReq queryReq = dslParseResult.getRequest();
Long modelId = parseInfo.getModelId();
String querySql = convertToSql(queryReq.getQueryFilters(), llmResp, parseInfo, modelId);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.queryFilters(queryReq.getQueryFilters())
.sql(llmResp.getSqlOutput())
.parseInfo(parseInfo)
.build();
List<DSLOptimizer> DSLCorrections = ComponentFactory.getSqlCorrections();
DSLCorrections.forEach(DSLCorrection -> {
try {
DSLCorrection.rewriter(correctionInfo);
log.info("sqlCorrection:{} sql:{}", DSLCorrection.getClass().getSimpleName(), correctionInfo.getSql());
} catch (Exception e) {
log.error("sqlCorrection:{} execute error,correctionInfo:{}", DSLCorrection, correctionInfo, e);
}
});
String querySql = correctionInfo.getSql();
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, modelId);
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, parseInfo.getModelId());
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
@@ -80,17 +93,4 @@ public class DSLQuery extends PluginSemanticQuery {
parseInfo.setProperties(null);
return queryResult;
}
protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo,
Long modelId) {
try {
return dslBuilder.build(parseInfo, queryFilters, llmResp, modelId);
} catch (Exception e) {
log.error("convertToSql error", e);
}
return null;
}
}

View File

@@ -0,0 +1,33 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class BaseDSLOptimizer implements DSLOptimizer {
public static final String DATE_FIELD = "数据日期";
protected Map<String, String> getFieldToBizName(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
dbAllFields.addAll(semanticSchema.getDimensions());
Map<String, String> result = dbAllFields.stream()
.filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
return result;
}
}

View File

@@ -0,0 +1,26 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class DateFieldCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql();
List<String> whereFields = CCJSqlParserUtils.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(BaseDSLOptimizer.DATE_FIELD)) {
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
sql = CCJSqlParserUtils.addWhere(sql, BaseDSLOptimizer.DATE_FIELD, currentDate);
}
correctionInfo.setSql(sql);
return correctionInfo;
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String replaceFields = CCJSqlParserUtils.replaceFields(correctionInfo.getSql(),
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String replaceFunction = CCJSqlParserUtils.replaceFunction(correctionInfo.getSql());
correctionInfo.setSql(replaceFunction);
return correctionInfo;
}
}

View File

@@ -0,0 +1,48 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class QueryFilterAppend extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
String sql = correctionInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
sql = CCJSqlParserUtils.addWhere(sql, expression);
}
correctionInfo.setSql(sql);
return correctionInfo;
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -0,0 +1,35 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class SelectFieldAppendCorrector extends BaseDSLOptimizer {
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
String sql = correctionInfo.getSql();
if (CCJSqlParserUtils.hasAggregateFunction(sql)) {
return correctionInfo;
}
Set<String> selectFields = new HashSet<>(CCJSqlParserUtils.getSelectFields(sql));
Set<String> whereFields = new HashSet<>(CCJSqlParserUtils.getWhereFields(sql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return correctionInfo;
}
whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = CCJSqlParserUtils.addFieldsToSelect(sql, new ArrayList<>(whereFields));
correctionInfo.setSql(replaceFields);
return correctionInfo;
}
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TableNameCorrector extends BaseDSLOptimizer {
public static final String TABLE_PREFIX = "t_";
@Override
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
Long modelId = correctionInfo.getParseInfo().getModelId();
String sqlOutput = correctionInfo.getSql();
String replaceTable = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
correctionInfo.setSql(replaceTable);
return correctionInfo;
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
package com.tencent.supersonic.chat.query.metricInterpret;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
package com.tencent.supersonic.chat.query.metricInterpret;
import lombok.Data;

View File

@@ -0,0 +1,143 @@
package com.tencent.supersonic.chat.query.metricInterpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@Component
public class MetricInterpretQuery extends PluginSemanticQuery {
public final static String QUERY_MODE = "METRIC_INTERPRET";
@Override
public String getQueryMode() {
return QUERY_MODE;
}
public MetricInterpretQuery() {
QueryManager.register(this);
}
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
fillAggregator(queryStructReq, parseInfo.getMetrics());
queryStructReq.setNativeQuery(true);
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticLayer.queryByStruct(queryStructReq, user);
String text = generateTableText(queryResultWithSchemaResp);
Map<String, Object> properties = parseInfo.getProperties();
Map<String, String> replacedMap = new HashMap<>();
String textReplaced = replaceText((String) properties.get("queryText"), parseInfo.getElementMatches(), replacedMap);
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果","string","answer"));
Map<String, Object> result = new HashMap<>();
result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList();
resultList.add(result);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(queryColumns);
queryResult.setQueryMode(getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches, Map<String, String> replacedMap) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text;
}
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
String detectWord = schemaElementMatch.getDetectWord();
if (StringUtils.isBlank(detectWord)) {
continue;
}
text = text.replace(detectWord, "xxx");
replacedMap.put("xxx", detectWord);
}
return text;
}
private void fillAggregator(QueryStructReq queryStructReq, Set<SchemaElement> schemaElements) {
queryStructReq.getAggregators().clear();
for (SchemaElement schemaElement : schemaElements) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(schemaElement.getBizName());
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setNameCh(schemaElement.getName());
queryStructReq.getAggregators().add(aggregator);
}
}
private String replaceAnswer(String text, Map<String, String> replacedMap) {
for (String key : replacedMap.keySet()) {
text = text.replaceAll(key, replacedMap.get(key));
}
return text;
}
public static String generateTableText(QueryResultWithSchemaResp result) {
StringBuilder tableBuilder = new StringBuilder();
for (QueryColumn column : result.getColumns()) {
tableBuilder.append(column.getName()).append("\t");
}
tableBuilder.append("\n");
for (Map<String, Object> row : result.getResultList()) {
for (QueryColumn column : result.getColumns()) {
tableBuilder.append(row.get(column.getNameEn())).append("\t");
}
tableBuilder.append("\n");
}
return tableBuilder.toString();
}
public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistant_message();
}
return null;
}
}

View File

@@ -2,15 +2,14 @@ package com.tencent.supersonic.chat.query.plugin.webpage;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
@@ -18,18 +17,18 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@Component
public class WebPageQuery extends PluginSemanticQuery {
@@ -107,17 +106,15 @@ public class WebPageQuery extends PluginSemanticQuery {
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.sorted(Comparator.comparingDouble(SchemaElementMatch::getSimilarity))
.filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0)
.forEach(schemaElementMatch -> {
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
if (queryFilterValue != null) {
if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) {
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
schemaElementMatch.getWord());
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()), schemaElementMatch.getWord());
}
} else {
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
schemaElementMatch.getWord());
elementValueMap.computeIfAbsent(String.valueOf(schemaElementMatch.getElement().getId()), k -> schemaElementMatch.getWord());
}
});
}

View File

@@ -208,8 +208,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
}
QueryResult queryResult = new QueryResult();
QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(
convertQueryStruct(), user);
QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(convertQueryStruct(), user);
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@@ -17,8 +17,7 @@ public class EntityFilterQuery extends EntityListQuery {
public EntityFilterQuery() {
super();
queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0);
queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0);
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
}
@Override

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.query.rule.entity;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Component
public class EntityIdQuery extends EntityListQuery {
public static final String QUERY_MODE = "ENTITY_ID";
public EntityIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
}

View File

@@ -0,0 +1,51 @@
package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.service.AgentService;
import org.springframework.web.bind.annotation.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
@RestController
@RequestMapping("/api/chat/agent")
public class AgentController {
private AgentService agentService;
public AgentController(AgentService agentService) {
this.agentService = agentService;
}
@PostMapping
public boolean createAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
agentService.createAgent(agent, user);
return true;
}
@PutMapping
public boolean updateAgent(@RequestBody Agent agent,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
agentService.updateAgent(agent, user);
return true;
}
@DeleteMapping("/{id}")
public boolean deleteAgent(@PathVariable("id") Integer id) {
agentService.deleteAgent(id);
return true;
}
@RequestMapping("/getAgentList")
public List<Agent> getAgentList() {
return agentService.getAgents();
}
}

View File

@@ -32,7 +32,7 @@ import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/chat/conf")
@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"})
public class ChatConfigController {
@Autowired

View File

@@ -18,7 +18,7 @@ import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/chat/manage")
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
public class ChatController {
private final ChatService chatService;

View File

@@ -20,7 +20,7 @@ import org.springframework.web.bind.annotation.RestController;
* query controller
*/
@RestController
@RequestMapping("/api/chat/query")
@RequestMapping({"/api/chat/query", "/openapi/chat/query"})
public class ChatQueryController {
@Autowired

View File

@@ -19,7 +19,7 @@ import org.springframework.web.bind.annotation.RestController;
* recommend controller
*/
@RestController
@RequestMapping("/api/chat/")
@RequestMapping({"/api/chat/", "/openapi/chat/"})
public class RecommendController {
@Autowired

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import java.util.List;
public interface AgentService {
List<Agent> getAgents();
void createAgent(Agent agent, User user);
void updateAgent(Agent agent, User user);
Agent getAgent(Integer id);
void deleteAgent(Integer id);
}

View File

@@ -0,0 +1,82 @@
package com.tencent.supersonic.chat.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
import com.tencent.supersonic.chat.service.AgentService;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Service;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
@Service
public class AgentServiceImpl implements AgentService {
private AgentRepository agentRepository;
public AgentServiceImpl(AgentRepository agentRepository) {
this.agentRepository = agentRepository;
}
@Override
public List<Agent> getAgents() {
return getAgentDOList().stream()
.map(this::convert).collect(Collectors.toList());
}
@Override
public void createAgent(Agent agent, User user) {
agentRepository.createAgent(convert(agent, user));
}
@Override
public void updateAgent(Agent agent, User user) {
agentRepository.updateAgent(convert(agent, user));
}
@Override
public Agent getAgent(Integer id) {
if (id == null) {
return null;
}
return convert(agentRepository.getAgent(id));
}
@Override
public void deleteAgent(Integer id) {
agentRepository.deleteAgent(id);
}
private List<AgentDO> getAgentDOList() {
return agentRepository.getAgents();
}
private Agent convert(AgentDO agentDO){
if (agentDO == null ) {
return null;
}
Agent agent = new Agent();
BeanUtils.copyProperties(agentDO,agent);
agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class));
return agent;
}
private AgentDO convert(Agent agent, User user){
AgentDO agentDO = new AgentDO();
BeanUtils.copyProperties(agent, agentDO);
agentDO.setConfig(agent.getAgentConfig());
agentDO.setExamples(JSONObject.toJSONString(agent.getExamples()));
agentDO.setCreatedAt(new Date());
agentDO.setCreatedBy(user.getName());
agentDO.setUpdatedAt(new Date());
agentDO.setUpdatedBy(user.getName());
if (agentDO.getStatus() == null) {
agentDO.setStatus(1);
}
return agentDO;
}
}

View File

@@ -2,26 +2,25 @@ package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.component.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.beans.BeanUtils;
@@ -63,12 +62,14 @@ public class QueryServiceImpl implements QueryService {
if (queryCtx.getCandidateQueries().size() > 0) {
log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
Collectors.toList()));
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries());
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
log.debug("pick after [{}]", selectedQueries.stream().collect(
Collectors.toList()));
List<SemanticParseInfo> selectedParses = selectedQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -138,7 +139,7 @@ public class QueryServiceImpl implements QueryService {
if (queryCtx.getCandidateQueries().size() > 0) {
log.info("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
Collectors.toList()));
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries());
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
log.info("pick after [{}]", selectedQueries.stream().collect(
Collectors.toList()));

View File

@@ -57,6 +57,7 @@ public class RecommendServiceImpl implements RecommendService {
item.setName(dimSchemaDesc.getName());
item.setBizName(dimSchemaDesc.getBizName());
item.setId(dimSchemaDesc.getId());
item.setAlias(dimSchemaDesc.getAlias());
return item;
}).collect(Collectors.toList());
@@ -70,6 +71,7 @@ public class RecommendServiceImpl implements RecommendService {
item.setName(metricSchemaDesc.getName());
item.setBizName(metricSchemaDesc.getBizName());
item.setId(metricSchemaDesc.getId());
item.setAlias(metricSchemaDesc.getAlias());
return item;
}).collect(Collectors.toList());

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
@@ -13,6 +14,7 @@ import com.tencent.supersonic.chat.mapper.MatchText;
import com.tencent.supersonic.chat.mapper.ModelInfoStat;
import com.tencent.supersonic.chat.mapper.ModelWithSemanticType;
import com.tencent.supersonic.chat.mapper.SearchMatchStrategy;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.SearchService;
import com.tencent.supersonic.chat.utils.NatureHelper;
@@ -53,22 +55,34 @@ public class SearchServiceImpl implements SearchService {
private ChatService chatService;
@Autowired
private SearchMatchStrategy searchMatchStrategy;
@Autowired
private AgentService agentService;
@Override
public List<SearchResult> search(QueryReq queryCtx) {
// 1. check search enable
Integer agentId = queryCtx.getAgentId();
if (agentId != null) {
Agent agent = agentService.getAgent(agentId);
if (!agent.enableSearch()) {
return Lists.newArrayList();
}
}
String queryText = queryCtx.getQueryText();
// 1.get meta info
// 2.get meta info
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
final Map<Long, String> modelToName = semanticSchemaDb.getModelIdToName();
// 2.detect by segment
// 3.detect by segment
List<Term> originals = HanlpHelper.getTerms(queryText);
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.match(queryText, originals,
queryCtx.getModelId());
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 3.get the most matching data
// 4.get the most matching data
Optional<Entry<MatchText, List<MapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
.stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
@@ -77,7 +91,7 @@ public class SearchServiceImpl implements SearchService {
? entry1 : entry2);
log.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult);
// 4.optimize the results after the query
// 5.optimize the results after the query
if (!mostSimilarSearchResult.isPresent()) {
return Lists.newArrayList();
}
@@ -89,11 +103,11 @@ public class SearchServiceImpl implements SearchService {
List<Long> possibleModels = getPossibleModels(queryCtx, originals, modelStat, queryCtx.getModelId());
// 4.1 priority dimension metric
// 5.1 priority dimension metric
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName,
searchTextEntry, searchResults);
// 4.2 process based on dimension values
// 5.2 process based on dimension values
MatchText matchText = searchTextEntry.getKey();
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);

View File

@@ -1,77 +0,0 @@
package com.tencent.supersonic.chat.utils;
import com.plexpt.chatgpt.ChatGPT;
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse;
import com.plexpt.chatgpt.entity.chat.Message;
import com.plexpt.chatgpt.util.Proxys;
import java.net.Proxy;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Date;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
@Component
public class ChatGptHelper {
@Value("${llm.chatgpt.apikey:xx-xxxx}")
private String apiKey;
@Value("${llm.chatgpt.apiHost:https://api.openai.com/}")
private String apiHost;
@Value("${llm.chatgpt.proxyIp:default}")
private String proxyIp;
@Value("${llm.chatgpt.proxyPort:8080}")
private Integer proxyPort;
public ChatGPT getChatGPT() {
Proxy proxy = null;
if (!"default".equals(proxyIp)) {
proxy = Proxys.http(proxyIp, proxyPort);
}
return ChatGPT.builder()
.apiKey(apiKey)
.proxy(proxy)
.timeout(900)
.apiHost(apiHost) //反向代理地址
.build()
.init();
}
public String inferredTime(String queryText) {
long nowTime = System.currentTimeMillis();
Date date = new Date(nowTime);
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
String formattedDate = sdf.format(date);
Message system = Message.ofSystem("现在时间 " + formattedDate + ",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。"
+ "你需要遵守以下规则:\n"
+ "1.返回规范的数据格式json 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}"
+ "2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n"
+ "3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n"
+ "4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n"
+ "5.时间给出要是绝对正确,不能瞎编\n"
);
Message message = Message.of("输入:" + queryText + ",输出:");
ChatCompletion chatCompletion = ChatCompletion.builder()
.model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName())
.messages(Arrays.asList(system, message))
.maxTokens(10000)
.temperature(0.9)
.build();
ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion);
Message res = response.getChoices().get(0).getMessage();
return res.getContent();
}
public static void main(String[] args) {
}
}

View File

@@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.utils;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.parser.function.ModelResolver;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.parser.function.ModelResolver;
import com.tencent.supersonic.chat.query.QuerySelector;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
@@ -15,10 +18,11 @@ public class ComponentFactory {
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<DSLOptimizer> dslCorrections = new ArrayList<>();
private static SemanticLayer semanticLayer;
private static QuerySelector querySelector;
private static ModelResolver modelResolver;
public static List<SchemaMapper> getSchemaMappers() {
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
}
@@ -27,6 +31,11 @@ public class ComponentFactory {
return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers;
}
public static List<DSLOptimizer> getSqlCorrections() {
return CollectionUtils.isEmpty(dslCorrections) ? init(DSLOptimizer.class, dslCorrections) : dslCorrections;
}
public static SemanticLayer getSemanticLayer() {
if (Objects.isNull(semanticLayer)) {
semanticLayer = init(SemanticLayer.class);

View File

@@ -74,7 +74,7 @@ public class NatureHelper {
return null;
}
public static boolean isDimensionValueClassId(String nature) {
public static boolean isDimensionValueModelId(String nature) {
if (StringUtils.isEmpty(nature)) {
return false;
}
@@ -104,7 +104,7 @@ public class NatureHelper {
}
private static long getDimensionValueCount(List<Term> terms) {
return terms.stream().filter(term -> isDimensionValueClassId(term.nature.toString())).count();
return terms.stream().filter(term -> isDimensionValueModelId(term.nature.toString())).count();
}
private static long getDimensionCount(List<Term> terms) {

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env bash
# python path
export python_path="/usr/local/bin/python3.9"
# pip path

View File

@@ -0,0 +1,303 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper">
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
<id column="id" jdbcType="INTEGER" property="id" />
<result column="name" jdbcType="VARCHAR" property="name" />
<result column="description" jdbcType="VARCHAR" property="description" />
<result column="status" jdbcType="INTEGER" property="status" />
<result column="examples" jdbcType="VARCHAR" property="examples" />
<result column="config" jdbcType="VARCHAR" property="config" />
<result column="created_by" jdbcType="VARCHAR" property="createdBy" />
<result column="created_at" jdbcType="TIMESTAMP" property="createdAt" />
<result column="updated_by" jdbcType="VARCHAR" property="updatedBy" />
<result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt" />
<result column="enable_search" jdbcType="INTEGER" property="enableSearch" />
</resultMap>
<sql id="Example_Where_Clause">
<where>
<foreach collection="oredCriteria" item="criteria" separator="or">
<if test="criteria.valid">
<trim prefix="(" prefixOverrides="and" suffix=")">
<foreach collection="criteria.criteria" item="criterion">
<choose>
<when test="criterion.noValue">
and ${criterion.condition}
</when>
<when test="criterion.singleValue">
and ${criterion.condition} #{criterion.value}
</when>
<when test="criterion.betweenValue">
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
</when>
<when test="criterion.listValue">
and ${criterion.condition}
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
#{listItem}
</foreach>
</when>
</choose>
</foreach>
</trim>
</if>
</foreach>
</where>
</sql>
<sql id="Update_By_Example_Where_Clause">
<where>
<foreach collection="example.oredCriteria" item="criteria" separator="or">
<if test="criteria.valid">
<trim prefix="(" prefixOverrides="and" suffix=")">
<foreach collection="criteria.criteria" item="criterion">
<choose>
<when test="criterion.noValue">
and ${criterion.condition}
</when>
<when test="criterion.singleValue">
and ${criterion.condition} #{criterion.value}
</when>
<when test="criterion.betweenValue">
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
</when>
<when test="criterion.listValue">
and ${criterion.condition}
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
#{listItem}
</foreach>
</when>
</choose>
</foreach>
</trim>
</if>
</foreach>
</where>
</sql>
<sql id="Base_Column_List">
id, name, description, status, examples, config, created_by, created_at, updated_by,
updated_at, enable_search
</sql>
<select id="selectByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample" resultMap="BaseResultMap">
select
<if test="distinct">
distinct
</if>
<include refid="Base_Column_List" />
from s2_agent
<if test="_parameter != null">
<include refid="Example_Where_Clause" />
</if>
<if test="orderByClause != null">
order by ${orderByClause}
</if>
<if test="limitStart != null and limitStart>=0">
limit #{limitStart} , #{limitEnd}
</if>
</select>
<select id="selectByPrimaryKey" parameterType="java.lang.Integer" resultMap="BaseResultMap">
select
<include refid="Base_Column_List" />
from s2_agent
where id = #{id,jdbcType=INTEGER}
</select>
<delete id="deleteByPrimaryKey" parameterType="java.lang.Integer">
delete from s2_agent
where id = #{id,jdbcType=INTEGER}
</delete>
<insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
insert into s2_agent (id, name, description,
status, examples, config,
created_by, created_at, updated_by,
updated_at, enable_search)
values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR},
#{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR},
#{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR},
#{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER})
</insert>
<insert id="insertSelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
insert into s2_agent
<trim prefix="(" suffix=")" suffixOverrides=",">
<if test="id != null">
id,
</if>
<if test="name != null">
name,
</if>
<if test="description != null">
description,
</if>
<if test="status != null">
status,
</if>
<if test="examples != null">
examples,
</if>
<if test="config != null">
config,
</if>
<if test="createdBy != null">
created_by,
</if>
<if test="createdAt != null">
created_at,
</if>
<if test="updatedBy != null">
updated_by,
</if>
<if test="updatedAt != null">
updated_at,
</if>
<if test="enableSearch != null">
enable_search,
</if>
</trim>
<trim prefix="values (" suffix=")" suffixOverrides=",">
<if test="id != null">
#{id,jdbcType=INTEGER},
</if>
<if test="name != null">
#{name,jdbcType=VARCHAR},
</if>
<if test="description != null">
#{description,jdbcType=VARCHAR},
</if>
<if test="status != null">
#{status,jdbcType=INTEGER},
</if>
<if test="examples != null">
#{examples,jdbcType=VARCHAR},
</if>
<if test="config != null">
#{config,jdbcType=VARCHAR},
</if>
<if test="createdBy != null">
#{createdBy,jdbcType=VARCHAR},
</if>
<if test="createdAt != null">
#{createdAt,jdbcType=TIMESTAMP},
</if>
<if test="updatedBy != null">
#{updatedBy,jdbcType=VARCHAR},
</if>
<if test="updatedAt != null">
#{updatedAt,jdbcType=TIMESTAMP},
</if>
<if test="enableSearch != null">
#{enableSearch,jdbcType=INTEGER},
</if>
</trim>
</insert>
<select id="countByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample" resultType="java.lang.Long">
select count(*) from s2_agent
<if test="_parameter != null">
<include refid="Example_Where_Clause" />
</if>
</select>
<update id="updateByExampleSelective" parameterType="map">
update s2_agent
<set>
<if test="record.id != null">
id = #{record.id,jdbcType=INTEGER},
</if>
<if test="record.name != null">
name = #{record.name,jdbcType=VARCHAR},
</if>
<if test="record.description != null">
description = #{record.description,jdbcType=VARCHAR},
</if>
<if test="record.status != null">
status = #{record.status,jdbcType=INTEGER},
</if>
<if test="record.examples != null">
examples = #{record.examples,jdbcType=VARCHAR},
</if>
<if test="record.config != null">
config = #{record.config,jdbcType=VARCHAR},
</if>
<if test="record.createdBy != null">
created_by = #{record.createdBy,jdbcType=VARCHAR},
</if>
<if test="record.createdAt != null">
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
</if>
<if test="record.updatedBy != null">
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
</if>
<if test="record.updatedAt != null">
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
</if>
<if test="record.enableSearch != null">
enable_search = #{record.enableSearch,jdbcType=INTEGER},
</if>
</set>
<if test="_parameter != null">
<include refid="Update_By_Example_Where_Clause" />
</if>
</update>
<update id="updateByExample" parameterType="map">
update s2_agent
set id = #{record.id,jdbcType=INTEGER},
name = #{record.name,jdbcType=VARCHAR},
description = #{record.description,jdbcType=VARCHAR},
status = #{record.status,jdbcType=INTEGER},
examples = #{record.examples,jdbcType=VARCHAR},
config = #{record.config,jdbcType=VARCHAR},
created_by = #{record.createdBy,jdbcType=VARCHAR},
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
enable_search = #{record.enableSearch,jdbcType=INTEGER}
<if test="_parameter != null">
<include refid="Update_By_Example_Where_Clause" />
</if>
</update>
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
update s2_agent
<set>
<if test="name != null">
name = #{name,jdbcType=VARCHAR},
</if>
<if test="description != null">
description = #{description,jdbcType=VARCHAR},
</if>
<if test="status != null">
status = #{status,jdbcType=INTEGER},
</if>
<if test="examples != null">
examples = #{examples,jdbcType=VARCHAR},
</if>
<if test="config != null">
config = #{config,jdbcType=VARCHAR},
</if>
<if test="createdBy != null">
created_by = #{createdBy,jdbcType=VARCHAR},
</if>
<if test="createdAt != null">
created_at = #{createdAt,jdbcType=TIMESTAMP},
</if>
<if test="updatedBy != null">
updated_by = #{updatedBy,jdbcType=VARCHAR},
</if>
<if test="updatedAt != null">
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
</if>
<if test="enableSearch != null">
enable_search = #{enableSearch,jdbcType=INTEGER},
</if>
</set>
where id = #{id,jdbcType=INTEGER}
</update>
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
update s2_agent
set name = #{name,jdbcType=VARCHAR},
description = #{description,jdbcType=VARCHAR},
status = #{status,jdbcType=INTEGER},
examples = #{examples,jdbcType=VARCHAR},
config = #{config,jdbcType=VARCHAR},
created_by = #{createdBy,jdbcType=VARCHAR},
created_at = #{createdAt,jdbcType=TIMESTAMP},
updated_by = #{updatedBy,jdbcType=VARCHAR},
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
enable_search = #{enableSearch,jdbcType=INTEGER}
where id = #{id,jdbcType=INTEGER}
</update>
</mapper>

View File

@@ -60,3 +60,17 @@ CREATE TABLE `chat_query`
KEY `common1` (`user_name`),
KEY `common2` (`chat_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
CREATE TABLE `chat`
(
`chat_id` bigint(8) NOT NULL AUTO_INCREMENT,
`chat_name` varchar(100) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
`last_time` datetime DEFAULT NULL,
`creator` varchar(30) DEFAULT NULL,
`last_question` varchar(200) DEFAULT NULL,
`is_delete` int(2) DEFAULT '0' COMMENT 'is deleted',
`is_top` int(2) DEFAULT '0' COMMENT 'is top',
PRIMARY KEY (`chat_id`)
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;

View File

@@ -0,0 +1,37 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class DateFieldCorrectorTest {
@Test
void rewriter() {
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.sql("select count(歌曲名) from 歌曲库 ")
.parseInfo(parseInfo)
.build();
CorrectionInfo rewriter = dateFieldCorrector.rewriter(correctionInfo);
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", rewriter.getSql());
correctionInfo = CorrectionInfo.builder()
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
.parseInfo(parseInfo)
.build();
rewriter = dateFieldCorrector.rewriter(correctionInfo);
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", rewriter.getSql());
}
}

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.chat.query.dsl.optimizer;
import static org.junit.jupiter.api.Assertions.*;
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class SelectFieldAppendCorrectorTest {
@Test
void rewriter() {
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
CorrectionInfo correctionInfo = CorrectionInfo.builder()
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
.build();
CorrectionInfo rewriter = corrector.rewriter(correctionInfo);
Assert.assertEquals(
"SELECT 歌曲名, 歌手名, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
rewriter.getSql());
}
}

View File

@@ -1,13 +1,17 @@
package com.tencent.supersonic.knowledge.dictionary.builder;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
* dimension word nature
@@ -23,6 +27,7 @@ public class DimensionWordBuilder extends BaseWordBuilder {
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
result.add(getOnwWordNature(word, schemaElement, false));
result.addAll(getOnwWordNatureAlias(schemaElement, false));
if (nlpDimensionUseSuffix) {
String reverseWord = StringUtils.reverse(word);
if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) {
@@ -46,4 +51,16 @@ public class DimensionWordBuilder extends BaseWordBuilder {
return dictWord;
}
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
List<DictWord> dictWords = new ArrayList<>();
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
return dictWords;
}
for (String alias : schemaElement.getAlias()) {
dictWords.add(getOnwWordNature(alias, schemaElement, false));
}
return dictWords;
}
}

View File

@@ -1,13 +1,17 @@
package com.tencent.supersonic.knowledge.dictionary.builder;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
/**
* Metric DictWord
@@ -22,6 +26,7 @@ public class MetricWordBuilder extends BaseWordBuilder {
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
List<DictWord> result = Lists.newArrayList();
result.add(getOnwWordNature(word, schemaElement, false));
result.addAll(getOnwWordNatureAlias(schemaElement, false));
if (nlpMetricUseSuffix) {
String reverseWord = StringUtils.reverse(word);
if (!word.equalsIgnoreCase(reverseWord)) {
@@ -45,4 +50,16 @@ public class MetricWordBuilder extends BaseWordBuilder {
return dictWord;
}
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
List<DictWord> dictWords = new ArrayList<>();
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
return dictWords;
}
for (String alias : schemaElement.getAlias()) {
dictWords.add(getOnwWordNature(alias, schemaElement, false));
}
return dictWords;
}
}

View File

@@ -2,51 +2,38 @@ package com.tencent.supersonic.knowledge.semantic;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.S2ThreadContext;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.*;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.DomainService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.query.service.QueryService;
import com.tencent.supersonic.semantic.query.service.SchemaService;
import java.util.List;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class LocalSemanticLayer extends BaseSemanticLayer {
private SchemaService schemaService;
private S2ThreadContext s2ThreadContext;
private DomainService domainService;
private ModelService modelService;
private DimensionService dimensionService;
private MetricService metricService;
@SneakyThrows
@Override
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
try {
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user){
QueryService queryService = ContextUtils.getBean(QueryService.class);
QueryResultWithSchemaResp queryResultWithSchemaResp = queryService.queryByStruct(queryStructReq, user);
return queryResultWithSchemaResp;
} catch (Exception e) {
log.info("queryByStruct has an exception:{}", e.toString());
}
return null;
return queryService.queryByStructWithAuth(queryStructReq, user);
}
@Override

View File

@@ -8,21 +8,17 @@ import com.tencent.supersonic.semantic.api.model.pojo.Entity;
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
public class ModelSchemaBuilder {
private static String aliasSplit = ",";
public static ModelSchema build(ModelSchemaResp resp) {
ModelSchema domainSchema = new ModelSchema();
@@ -37,6 +33,13 @@ public class ModelSchemaBuilder {
Set<SchemaElement> metrics = new HashSet<>();
for (MetricSchemaResp metric : resp.getMetrics()) {
List<String> alias = new ArrayList<>();
String aliasStr = metric.getAlias();
if (Strings.isNotEmpty(aliasStr)) {
alias = Arrays.asList(aliasStr.split(aliasSplit));
}
SchemaElement metricToAdd = SchemaElement.builder()
.model(resp.getId())
.id(metric.getId())
@@ -44,16 +47,10 @@ public class ModelSchemaBuilder {
.bizName(metric.getBizName())
.type(SchemaElementType.METRIC)
.useCnt(metric.getUseCnt())
.alias(alias)
.build();
metrics.add(metricToAdd);
String alias = metric.getAlias();
if (StringUtils.isNotEmpty(alias)) {
SchemaElement alisMetricToAdd = new SchemaElement();
BeanUtils.copyProperties(metricToAdd, alisMetricToAdd);
alisMetricToAdd.setName(alias);
metrics.add(alisMetricToAdd);
}
}
domainSchema.getMetrics().addAll(metrics);
@@ -74,6 +71,11 @@ public class ModelSchemaBuilder {
}
}
List<String> alias = new ArrayList<>();
String aliasStr = dim.getAlias();
if (Strings.isNotEmpty(aliasStr)) {
alias = Arrays.asList(aliasStr.split(aliasSplit));
}
SchemaElement dimToAdd = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
@@ -81,17 +83,10 @@ public class ModelSchemaBuilder {
.bizName(dim.getBizName())
.type(SchemaElementType.DIMENSION)
.useCnt(dim.getUseCnt())
.alias(alias)
.build();
dimensions.add(dimToAdd);
String alias = dim.getAlias();
if (StringUtils.isNotEmpty(alias)) {
SchemaElement alisDimToAdd = new SchemaElement();
BeanUtils.copyProperties(dimToAdd, alisDimToAdd);
alisDimToAdd.setName(alias);
dimensions.add(alisDimToAdd);
}
SchemaElement dimValueToAdd = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
@@ -115,7 +110,7 @@ public class ModelSchemaBuilder {
.collect(
Collectors.toMap(SchemaElement::getId, schemaElement -> schemaElement, (k1, k2) -> k2));
if (idAndDimPair.containsKey(entity.getEntityId())) {
entityElement = idAndDimPair.get(entity.getEntityId());
BeanUtils.copyProperties(idAndDimPair.get(entity.getEntityId()), entityElement);
entityElement.setType(SchemaElementType.ENTITY);
}
entityElement.setAlias(entity.getNames());

View File

@@ -118,6 +118,12 @@
</dependency>
<dependency>
<groupId>com.github.plexpt</groupId>
<artifactId>chatgpt</artifactId>
<version>4.1.2</version>
</dependency>
<dependency>
<groupId>com.github.pagehelper</groupId>
<artifactId>pagehelper</artifactId>

View File

@@ -0,0 +1,130 @@
package com.tencent.supersonic.common.util;
import com.plexpt.chatgpt.ChatGPT;
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse;
import com.plexpt.chatgpt.entity.chat.Message;
import com.plexpt.chatgpt.util.Proxys;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.net.Proxy;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Date;
@Component
@Slf4j
public class ChatGptHelper {
@Value("${llm.chatgpt.apikey:}")
private String apiKey;
@Value("${llm.chatgpt.apiHost:}")
private String apiHost;
@Value("${llm.chatgpt.proxyIp:}")
private String proxyIp;
@Value("${llm.chatgpt.proxyPort:}")
private Integer proxyPort;
public ChatGPT getChatGPT(){
Proxy proxy = null;
if (!"default".equals(proxyIp)){
proxy = Proxys.http(proxyIp, proxyPort);
}
return ChatGPT.builder()
.apiKey(apiKey)
.proxy(proxy)
.timeout(900)
.apiHost(apiHost) //反向代理地址
.build()
.init();
}
public Message getChatCompletion(Message system,Message message){
ChatCompletion chatCompletion = ChatCompletion.builder()
.model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName())
.messages(Arrays.asList(system, message))
.maxTokens(10000)
.temperature(0.9)
.build();
ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion);
return response.getChoices().get(0).getMessage();
}
public String inferredTime(String queryText){
long nowTime = System.currentTimeMillis();
Date date = new Date(nowTime);
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
String formattedDate = sdf.format(date);
Message system = Message.ofSystem("现在时间 "+formattedDate+",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。" +
"你需要遵守以下规则:\n" +
"1.返回规范的数据格式json 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}" +
"2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n"+
"3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n" +
"4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n"+
"5.时间给出要是绝对正确,不能瞎编\n"
);
Message message = Message.of("输入:"+queryText+",输出:");
Message res = getChatCompletion(system, message);
return res.getContent();
}
public String mockAlias(String mockType,String name,String bizName,String table,String desc,Boolean isPercentage){
String msg = "Assuming you are a professional data analyst specializing in indicators, you have a vast amount of data analysis indicator content. You are familiar with the basic format of the content,Now, Construct your answer Based on the following json-schema.\n" +
"{\n" +
"\"$schema\": \"http://json-schema.org/draft-07/schema#\",\n" +
"\"type\": \"array\",\n" +
"\"minItems\": 2,\n" +
"\"maxItems\": 4,\n" +
"\"items\": {\n" +
"\"type\": \"string\",\n" +
"\"description\": \"Assuming you are a data analyst and give a defined "+mockType+" name: " +name+","+
"this "+mockType+" is from database and table: "+table+ ",This "+mockType+" calculates the field source: "+bizName+", The description of this indicator is: "+desc+", provide some aliases for thisplease take chinese or english,but more chinese and Not repeating\"\n" +
"},\n" +
"\"additionalProperties\":false}\n" +
"Please double-check whether the answer conforms to the format described in the JSON-schema.\n" +
"ANSWER JSON:";
log.info("msg:{}",msg);
Message system = Message.ofSystem("");
Message message = Message.of(msg);
Message res = getChatCompletion(system, message);
return res.getContent();
}
public String mockDimensionValueAlias(String json){
String msg = "Assuming you are a professional data analyst specializing in indicators,for you a json list" +
"the required content to follow is as follows: " +
"1. The format of JSON," +
"2. Only return in JSON format," +
"3. the array item > 1 and < 5,more alias," +
"for exampleinput:[\"qq_music\",\"kugou_music\"],out:{\"tran\":[\"qq音乐\",\"酷狗音乐\"],\"alias\":{\"qq_music\":[\"q音\",\"qq音乐\"],\"kugou_music\":[\"kugou\",\"酷狗\"]}}," +
"now input: " + json + ","+
"answer json:";
log.info("msg:{}",msg);
Message system = Message.ofSystem("");
Message message = Message.of(msg);
Message res = getChatCompletion(system, message);
return res.getContent();
}
public static void main(String[] args) {
ChatGptHelper chatGptHelper = new ChatGptHelper();
System.out.println(chatGptHelper.mockAlias("","","","","",false));
}
}

View File

@@ -1,11 +1,13 @@
package com.tencent.supersonic.common.util;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.TemporalAdjusters;
import java.util.Calendar;
import java.util.Date;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -56,30 +58,37 @@ public class DateUtils {
return dateFormat.format(calendar.getTime());
}
public static String getBeforeDate(String date, int intervalDay, DatePeriodEnum datePeriodEnum) {
Calendar calendar = Calendar.getInstance();
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT_DOT);
try {
calendar.setTime(dateFormat.parse(date));
} catch (ParseException e) {
log.error("parse error");
}
DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(DATE_FORMAT_DOT);
LocalDate currentDate = LocalDate.parse(date, dateTimeFormatter);
LocalDate result = null;
switch (datePeriodEnum) {
case DAY:
calendar.set(Calendar.DATE, calendar.get(Calendar.DATE) - intervalDay);
result = currentDate.minusDays(intervalDay);
break;
case WEEK:
calendar.set(Calendar.DATE, calendar.get(Calendar.DATE) - intervalDay * 7);
result = currentDate.minusWeeks(intervalDay);
if (intervalDay == 0) {
result = result.with(TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY));
}
break;
case MONTH:
calendar.set(Calendar.MONTH, calendar.get(Calendar.MONTH) - intervalDay);
result = currentDate.minusMonths(intervalDay);
if (intervalDay == 0) {
result = result.with(TemporalAdjusters.firstDayOfMonth());
}
break;
case YEAR:
calendar.set(Calendar.YEAR, calendar.get(Calendar.YEAR) - intervalDay);
result = currentDate.minusYears(intervalDay);
if (intervalDay == 0) {
result = result.with(TemporalAdjusters.firstDayOfYear());
}
break;
default:
}
return dateFormat.format(calendar.getTime());
if (Objects.nonNull(result)) {
return result.format(DateTimeFormatter.ofPattern(DATE_FORMAT_DOT));
}
return null;
}
}

View File

@@ -9,7 +9,6 @@ public class StringUtil {
public static final String COMMA_WRAPPER = "'%s'";
public static final String SPACE_WRAPPER = " %s ";
public static String getCommaWrap(String value) {
return String.format(COMMA_WRAPPER, value);
}

View File

@@ -41,7 +41,9 @@ public class CCJSqlParserUtils {
}
Set<String> result = new HashSet<>();
Expression where = plainSelect.getWhere();
if (Objects.nonNull(where)) {
where.accept(new FieldAcquireVisitor(result));
}
return new ArrayList<>(result);
}
@@ -166,7 +168,24 @@ public class CCJSqlParserUtils {
if (Objects.nonNull(groupByElement)) {
groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName));
}
//5. add Waiting Expression
return selectStatement.toString();
}
public static String replaceFunction(String sql) {
Select selectStatement = getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectBody instanceof PlainSelect)) {
return sql;
}
PlainSelect plainSelect = (PlainSelect) selectBody;
//1. replace where dataDiff function
Expression where = plainSelect.getWhere();
FunctionReplaceVisitor visitor = new FunctionReplaceVisitor();
if (Objects.nonNull(where)) {
where.accept(visitor);
}
//2. add Waiting Expression
List<Expression> waitingForAdds = visitor.getWaitingForAdds();
addWaitingExpression(plainSelect, where, waitingForAdds);
return selectStatement.toString();
@@ -181,9 +200,10 @@ public class CCJSqlParserUtils {
if (where == null) {
plainSelect.setWhere(expression);
} else {
plainSelect.setWhere(new AndExpression(where, expression));
where = new AndExpression(where, expression);
}
}
plainSelect.setWhere(where);
}

View File

@@ -1,25 +1,15 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column;
@Slf4j
public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
private Map<String, String> fieldToBizName;
private List<Expression> waitingForAdds = new ArrayList<>();
public FieldReplaceVisitor(Map<String, String> fieldToBizName) {
this.fieldToBizName = fieldToBizName;
@@ -29,42 +19,4 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
public void visit(Column column) {
parseVisitorHelper.replaceColumn(column, fieldToBizName);
}
@Override
public void visit(MinorThan expr) {
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, ">");
if (Objects.nonNull(expression)) {
waitingForAdds.add(expression);
}
}
@Override
public void visit(MinorThanEquals expr) {
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, ">=");
if (Objects.nonNull(expression)) {
waitingForAdds.add(expression);
}
}
@Override
public void visit(GreaterThan expr) {
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, "<");
if (Objects.nonNull(expression)) {
waitingForAdds.add(expression);
}
}
@Override
public void visit(GreaterThanEquals expr) {
Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, "<=");
if (Objects.nonNull(expression)) {
waitingForAdds.add(expression);
}
}
public List<Expression> getWaitingForAdds() {
return waitingForAdds;
}
}

View File

@@ -0,0 +1,171 @@
package com.tencent.supersonic.common.util.jsqlparser;
import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.StringUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class FunctionReplaceVisitor extends ExpressionVisitorAdapter {
public static final String DATE_FUNCTION = "datediff";
public static final double HALF_YEAR = 0.5d;
public static final int SIX_MONTH = 6;
public static final String EQUAL = "=";
private List<Expression> waitingForAdds = new ArrayList<>();
@Override
public void visit(MinorThan expr) {
List<Expression> expressions = reparseDate(expr, ">");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
}
@Override
public void visit(EqualsTo expr) {
List<Expression> expressions = reparseDate(expr, ">=");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
}
@Override
public void visit(MinorThanEquals expr) {
List<Expression> expressions = reparseDate(expr, ">=");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
}
@Override
public void visit(GreaterThan expr) {
List<Expression> expressions = reparseDate(expr, "<");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
}
@Override
public void visit(GreaterThanEquals expr) {
List<Expression> expressions = reparseDate(expr, "<=");
if (Objects.nonNull(expressions)) {
waitingForAdds.addAll(expressions);
}
}
public List<Expression> getWaitingForAdds() {
return waitingForAdds;
}
public List<Expression> reparseDate(ComparisonOperator comparisonOperator, String startDateOperator) {
List<Expression> result = new ArrayList<>();
Expression leftExpression = comparisonOperator.getLeftExpression();
if (!(leftExpression instanceof Function)) {
return result;
}
Function leftExpressionFunction = (Function) leftExpression;
if (!leftExpressionFunction.toString().contains(DATE_FUNCTION)) {
return result;
}
List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) {
return result;
}
Column field = (Column) leftExpressions.get(1);
String columnName = field.getColumnName();
try {
String startDateValue = getStartDateStr(comparisonOperator, leftExpressions);
String endDateValue = getEndDateValue(leftExpressions);
String endDateOperator = comparisonOperator.getStringExpression();
String condExpr =
columnName + StringUtil.getSpaceWrap(getEndDateOperator(comparisonOperator))
+ StringUtil.getCommaWrap(endDateValue);
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
String startDataCondExpr =
columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
if (EQUAL.equalsIgnoreCase(endDateOperator)) {
result.add(CCJSqlParserUtil.parseCondExpression(condExpr));
expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 ");
}
comparisonOperator.setLeftExpression(null);
comparisonOperator.setRightExpression(null);
comparisonOperator.setASTNode(null);
comparisonOperator.setLeftExpression(expression.getLeftExpression());
comparisonOperator.setRightExpression(expression.getRightExpression());
comparisonOperator.setASTNode(expression.getASTNode());
result.add(CCJSqlParserUtil.parseCondExpression(startDataCondExpr));
return result;
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
return null;
}
private String getStartDateStr(ComparisonOperator minorThanEquals, List<Expression> expressions) {
String unitValue = getUnit(expressions);
String dateValue = getEndDateValue(expressions);
String dateStr = "";
Expression rightExpression = minorThanEquals.getRightExpression();
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(unitValue);
if (rightExpression instanceof DoubleValue) {
DoubleValue value = (DoubleValue) rightExpression;
double doubleValue = value.getValue();
if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) {
datePeriodEnum = DatePeriodEnum.MONTH;
dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum);
}
} else if (rightExpression instanceof LongValue) {
LongValue value = (LongValue) rightExpression;
long doubleValue = value.getValue();
dateStr = DateUtils.getBeforeDate(dateValue, (int) doubleValue, datePeriodEnum);
}
return dateStr;
}
private String getEndDateOperator(ComparisonOperator comparisonOperator) {
String operator = comparisonOperator.getStringExpression();
if (EQUAL.equalsIgnoreCase(operator)) {
operator = "<=";
}
return operator;
}
private String getEndDateValue(List<Expression> leftExpressions) {
StringValue date = (StringValue) leftExpressions.get(2);
return date.getValue();
}
private String getUnit(List<Expression> expressions) {
StringValue unit = (StringValue) expressions.get(0);
return unit.getValue();
}
}

View File

@@ -2,13 +2,18 @@ package com.tencent.supersonic.common.util.jsqlparser;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.GroupByElement;
import net.sf.jsqlparser.statement.select.GroupByVisitor;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class GroupByReplaceVisitor implements GroupByVisitor {
ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper();
@@ -29,8 +34,18 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName);
if (StringUtils.isNotEmpty(replaceColumn)) {
if (expression instanceof Column) {
groupByExpressions.set(i, new Column(replaceColumn));
}
if (expression instanceof Function) {
try {
Expression element = CCJSqlParserUtil.parseExpression(replaceColumn);
((Function) expression).getParameters().getExpressions().set(0, element);
} catch (JSQLParserException e) {
log.error("e", e);
}
}
}
}
}
}

View File

@@ -1,33 +1,16 @@
package com.tencent.supersonic.common.util.jsqlparser;
import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.StringUtil;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.DoubleValue;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class ParseVisitorHelper {
public static final double HALF_YEAR = 0.5d;
public static final int SIX_MONTH = 6;
public static final String DATE_FUNCTION = "datediff";
public void replaceColumn(Column column, Map<String, String> fieldToBizName) {
String columnName = column.getColumnName();
column.setColumnName(getReplaceColumn(columnName, fieldToBizName));
@@ -53,88 +36,6 @@ public class ParseVisitorHelper {
return columnName;
}
public Expression reparseDate(ComparisonOperator comparisonOperator, Map<String, String> fieldToBizName,
String startDateOperator) {
Expression leftExpression = comparisonOperator.getLeftExpression();
if (leftExpression instanceof Column) {
Column leftExpressionColumn = (Column) leftExpression;
replaceColumn(leftExpressionColumn, fieldToBizName);
return null;
}
if (!(leftExpression instanceof Function)) {
return null;
}
Function leftExpressionFunction = (Function) leftExpression;
if (!leftExpressionFunction.toString().contains(DATE_FUNCTION)) {
return null;
}
List<Expression> leftExpressions = leftExpressionFunction.getParameters().getExpressions();
if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) {
return null;
}
Column field = (Column) leftExpressions.get(1);
String columnName = field.getColumnName();
String startDateValue = getStartDateStr(comparisonOperator, leftExpressions);
String fieldBizName = fieldToBizName.get(columnName);
try {
String endDateValue = getEndDateValue(leftExpressions);
String stringExpression = comparisonOperator.getStringExpression();
String condExpr =
fieldBizName + StringUtil.getSpaceWrap(stringExpression) + StringUtil.getCommaWrap(endDateValue);
ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr);
comparisonOperator.setLeftExpression(null);
comparisonOperator.setRightExpression(null);
comparisonOperator.setASTNode(null);
comparisonOperator.setLeftExpression(expression.getLeftExpression());
comparisonOperator.setRightExpression(expression.getRightExpression());
comparisonOperator.setASTNode(expression.getASTNode());
String startDataCondExpr =
fieldBizName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue);
return CCJSqlParserUtil.parseCondExpression(startDataCondExpr);
} catch (JSQLParserException e) {
log.error("JSQLParserException", e);
}
return null;
}
private String getEndDateValue(List<Expression> leftExpressions) {
StringValue date = (StringValue) leftExpressions.get(2);
return date.getValue();
}
private String getStartDateStr(ComparisonOperator minorThanEquals, List<Expression> expressions) {
String unitValue = getUnit(expressions);
String dateValue = getEndDateValue(expressions);
String dateStr = "";
Expression rightExpression = minorThanEquals.getRightExpression();
DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(unitValue);
if (rightExpression instanceof DoubleValue) {
DoubleValue value = (DoubleValue) rightExpression;
double doubleValue = value.getValue();
if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) {
datePeriodEnum = DatePeriodEnum.MONTH;
dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum);
}
} else if (rightExpression instanceof LongValue) {
LongValue value = (LongValue) rightExpression;
long doubleValue = value.getValue();
dateStr = DateUtils.getBeforeDate(dateValue, (int) doubleValue, datePeriodEnum);
}
return dateStr;
}
private String getUnit(List<Expression> expressions) {
StringValue unit = (StringValue) expressions.get(0);
return unit.getValue();
}
public static int editDistance(String word1, String word2) {
final int m = word1.length();
final int n = word2.length();

View File

@@ -14,13 +14,25 @@ class DateUtilsTest {
dateStr = DateUtils.getBeforeDate("2023-08-10", 8, DatePeriodEnum.DAY);
Assert.assertEquals(dateStr, "2023-08-02");
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.DAY);
Assert.assertEquals(dateStr, "2023-08-10");
dateStr = DateUtils.getBeforeDate("2023-08-10", 1, DatePeriodEnum.WEEK);
Assert.assertEquals(dateStr, "2023-08-03");
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.WEEK);
Assert.assertEquals(dateStr, "2023-08-07");
dateStr = DateUtils.getBeforeDate("2023-08-01", 1, DatePeriodEnum.MONTH);
Assert.assertEquals(dateStr, "2023-07-01");
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.MONTH);
Assert.assertEquals(dateStr, "2023-08-01");
dateStr = DateUtils.getBeforeDate("2023-08-01", 1, DatePeriodEnum.YEAR);
Assert.assertEquals(dateStr, "2022-08-01");
dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.YEAR);
Assert.assertEquals(dateStr, "2023-01-01");
}
}

View File

@@ -22,37 +22,104 @@ class CCJSqlParserUtilsTest {
String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11";
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND song_publis_date = '2023-08-01' AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11"
, replaceSql);
replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)";
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' GROUP BY YEAR(publish_date)",
replaceSql);
replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' group by 发行日期";
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' GROUP BY publish_date",
replaceSql);
replaceSql = CCJSqlParserUtils.replaceFields(
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30",
fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-11' AND play_count > 1000000 AND sys_imp_date <= '2023-08-11' AND publish_date >= '2022-08-11' AND sys_imp_date >= '2023-07-12'"
, replaceSql);
replaceSql = CCJSqlParserUtils.replaceFields(
"select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11"
, replaceSql);
replaceSql = CCJSqlParserUtils.replaceFields(
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name FROM 歌曲库 WHERE 1 = 1 AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-08-09' AND publish_date >= '2023-01-01' ORDER BY play_count DESC LIMIT 11"
, replaceSql);
replaceSql = CCJSqlParserUtils.replaceFields(
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date >= '2023-02-09' ORDER BY play_count DESC LIMIT 11"
, replaceSql);
replaceSql = CCJSqlParserUtils.replaceFields(
"select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11",
fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT song_name FROM 歌曲库 WHERE publish_date >= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-02-09' ORDER BY play_count DESC LIMIT 11"
, replaceSql);
replaceSql = CCJSqlParserUtils.replaceFields(
"select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 =alice and 发布日期 ='11' order by 访问次数 desc limit 1",
"select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice' and 发布日期 ='11' order by 访问次数 desc limit 1",
fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT department, user_id FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"
, replaceSql);
replaceSql = CCJSqlParserUtils.replaceTable(replaceSql, "s2");
replaceSql = CCJSqlParserUtils.addFieldsToSelect(replaceSql, Collections.singletonList("field_a"));
replaceSql = CCJSqlParserUtils.replaceFields(
"select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 =alice and 发布日期 ='11' group by 部门 limit 1",
"select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice' and 发布日期 ='11' group by 部门 limit 1",
fieldToBizName);
Assert.assertEquals(replaceSql,
"SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = user_id AND publish_date = '11' GROUP BY department LIMIT 1");
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
Assert.assertEquals(
"SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1",
replaceSql);
replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' and 数据日期 <= '2023-08-06' and 部门 = 'hr'";
replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName);
replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql);
System.out.println(replaceSql);
Assert.assertEquals(
"SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' AND sys_imp_date <= '2023-08-06' AND department = 'hr'",
replaceSql);
}

View File

@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticLayer=\
com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer
@@ -20,3 +21,12 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector

View File

@@ -1,6 +1,12 @@
package com.tencent.supersonic;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentConfig;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
@@ -14,16 +20,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.*;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.context.event.ApplicationReadyEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@@ -32,6 +36,7 @@ import org.springframework.stereotype.Component;
@Slf4j
public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent> {
@Qualifier("chatQueryService")
@Autowired
private QueryService queryService;
@Autowired
@@ -40,6 +45,8 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
protected ConfigService configService;
@Autowired
private PluginService pluginService;
@Autowired
private AgentService agentService;
private User user = User.getFakeUser();
@@ -175,43 +182,25 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
pluginService.createPlugin(plugin_1, user);
}
private void addPlugin_2() {
Plugin plugin_2 = new Plugin();
plugin_2.setType("DSL");
plugin_2.setModelList(Arrays.asList(1L, 2L));
plugin_2.setPattern("");
plugin_2.setParseModeConfig(null);
plugin_2.setName("大模型语义解析");
List<String> examples = new ArrayList<>();
examples.add("超音数访问次数最高的部门是哪个");
examples.add("超音数访问人数最高的部门是哪个");
PluginParseConfig parseConfig = PluginParseConfig.builder()
.name("DSL")
.description("这个工具能够将用户的自然语言查询转化为SQL语句从而从数据库中的查询具体的数据。用于处理数据查询的问题提供基于事实的数据")
.examples(examples)
.build();
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
pluginService.createPlugin(plugin_2, user);
}
private void addPlugin_3() {
Plugin plugin_2 = new Plugin();
plugin_2.setType("CONTENT_INTERPRET");
plugin_2.setModelList(Arrays.asList(1L));
plugin_2.setPattern("超音数最近访问情况怎么样");
plugin_2.setParseModeConfig(null);
plugin_2.setName("内容解读");
List<String> examples = new ArrayList<>();
examples.add("超音数最近访问情况怎么样");
examples.add("超音数最近访问情况如何");
PluginParseConfig parseConfig = PluginParseConfig.builder()
.name("supersonic_content_interpret")
.description("这个工具能够先查询到相关的数据并交给大模型进行解读, 最后返回解读结果")
.examples(examples)
.build();
plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig));
pluginService.createPlugin(plugin_2, user);
private void addAgent() {
Agent agent = new Agent();
agent.setId(1);
agent.setName("查信息");
agent.setDescription("查信息");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("超音数访问次数", "超音数访问人数", "alice 停留时长"));
AgentConfig agentConfig = new AgentConfig();
RuleQueryTool ruleQueryTool = new RuleQueryTool();
ruleQueryTool.setType(AgentToolType.RULE);
ruleQueryTool.setQueryModes(Lists.newArrayList(
"ENTITY_DETAIL", "ENTITY_LIST_FILTER", "ENTITY_ID", "METRIC_ENTITY",
"METRIC_FILTER", "METRIC_GROUPBY", "METRIC_MODEL", "METRIC_ORDERBY"
));
agentConfig.getTools().add(ruleQueryTool);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
}
@Override
@@ -220,8 +209,7 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
addDemoChatConfig_1();
addDemoChatConfig_2();
addPlugin_1();
addPlugin_2();
addPlugin_3();
addAgent();
addSampleChats();
addSampleChats2();
} catch (Exception e) {

View File

@@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\
com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.rule.QueryModeParser, \
com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \
com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \
com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \
com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \
com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \
com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \
com.tencent.supersonic.chat.parser.embedding.EmbeddingBasedParser, \
com.tencent.supersonic.chat.parser.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticLayer=\
@@ -21,3 +22,11 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.api.component.DSLOptimizer=\
com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \
com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \
com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector

View File

@@ -648,6 +648,22 @@ CREATE TABLE IF NOT EXISTS `s2_plugin`
COMMENT
ON TABLE s2_plugin IS 'plugin information table';
CREATE TABLE IF NOT EXISTS s2_agent
(
id int AUTO_INCREMENT,
name varchar(100) null,
description varchar(500) null,
status int null,
examples varchar(500) null,
config varchar(2000) null,
created_by varchar(100) null,
created_at TIMESTAMP null,
updated_by varchar(100) null,
updated_at TIMESTAMP null,
enable_search int null,
PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'assistant information table';
-------demo for semantic and chat
CREATE TABLE IF NOT EXISTS `s2_user_department`

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.QueryService;
@@ -22,6 +23,7 @@ import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@@ -42,6 +44,8 @@ public class BaseQueryTest {
protected ChatService chatService;
@Autowired
protected ConfigService configService;
@MockBean
protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText) throws Exception {
ParseResp parseResp = submitParse(queryText);
@@ -78,6 +82,11 @@ public class BaseQueryTest {
return queryService.performParsing(queryContextReq);
}
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(10, queryText, agentId);
return queryService.performParsing(queryContextReq);
}
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
Set<String> expectedNames = expected.stream().map(s -> s.getName())
.filter(s -> s != null).collect(Collectors.toSet());

View File

@@ -0,0 +1,56 @@
package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@RunWith(SpringRunner.class)
@SpringBootTest(classes = StandaloneLauncher.class)
@ActiveProfiles("local")
public class MetricInterpretTest {
@MockBean
private AgentService agentService;
@MockBean
private PluginManager pluginManager;
@MockBean
private EmbeddingConfig embeddingConfig;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@Test
public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLmAnswerResp lLmAnswerResp = new LLmAnswerResp();
lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
QueryResult queryResult = queryService.executeQuery(queryReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message());
}
}

View File

@@ -1,30 +1,31 @@
package com.tencent.supersonic.integration;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.util.DataUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.BeanUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
public class MetricQueryTest extends BaseQueryTest {
@@ -50,6 +51,17 @@ public class MetricQueryTest extends BaseQueryTest {
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_METRIC_FILTER_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
}
@Test
public void queryTest_METRIC_DOMAIN() throws Exception {
QueryResult actualResult = submitNewChat("超音数的访问次数");
@@ -69,6 +81,16 @@ public class MetricQueryTest extends BaseQueryTest {
assertQueryResult(expectedResult, actualResult);
}
@Test
public void queryTest_METRIC_MODEL_with_agent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
}
@Test
public void queryTest_METRIC_GROUPBY() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数");

View File

@@ -1,22 +1,23 @@
package com.tencent.supersonic.integration.plugin;
package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class PluginMockConfiguration {
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
EmbeddingResp embeddingResp = new EmbeddingResp();
@@ -33,9 +34,12 @@ public class PluginMockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}
public static void mockAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
}
}

Some files were not shown because too many files have changed in this diff Show More