mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement](supersonic) based on version 0.7.2 (#34)
Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
package com.tencent.supersonic.chat.api.component;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
|
||||
public interface DSLOptimizer {
|
||||
CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@Builder
|
||||
public class CorrectionInfo {
|
||||
|
||||
private QueryFilters queryFilters;
|
||||
|
||||
private SemanticParseInfo parseInfo;
|
||||
|
||||
private String sql;
|
||||
|
||||
}
|
||||
@@ -12,4 +12,5 @@ public class QueryReq {
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
}
|
||||
|
||||
@@ -40,12 +40,6 @@
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.plexpt</groupId>
|
||||
<artifactId>chatgpt</artifactId>
|
||||
<version>4.1.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class Agent extends RecordInfo {
|
||||
|
||||
private Integer id;
|
||||
private Integer enableSearch;
|
||||
private String name;
|
||||
private String description;
|
||||
|
||||
//0 offline, 1 online
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<Map> toolList = (List) map.get("tools");
|
||||
return toolList.stream()
|
||||
.filter(tool -> type.name().equals(tool.get("type")))
|
||||
.map(JSONObject::toJSONString)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public boolean enableSearch() {
|
||||
return enableSearch != null && enableSearch == 1;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.agent;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentTool;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class AgentConfig {
|
||||
|
||||
List<AgentTool> tools = Lists.newArrayList();
|
||||
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class AgentTool {
|
||||
|
||||
private String name;
|
||||
|
||||
private AgentToolType type;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
public enum AgentToolType {
|
||||
RULE,
|
||||
DSL,
|
||||
PLUGIN,
|
||||
INTERPRET
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class DslTool extends AgentTool {
|
||||
|
||||
private List<Long> modelIds;
|
||||
|
||||
private List<String> exampleQuestions;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
public class MetricInterpretTool extends AgentTool {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
private List<MetricOption> metricOptions;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class PluginTool extends AgentTool {
|
||||
|
||||
private List<Long> plugins;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.agent.tool;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class RuleQueryTool extends AgentTool {
|
||||
|
||||
private List<String> queryModes;
|
||||
|
||||
}
|
||||
@@ -19,7 +19,7 @@ import org.springframework.stereotype.Service;
|
||||
@Slf4j
|
||||
public class MapperHelper {
|
||||
|
||||
@Value("${one.detection.size:6}")
|
||||
@Value("${one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
@Value("${one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
@@ -64,7 +64,7 @@ public class MapperHelper {
|
||||
*/
|
||||
public boolean existDimensionValues(List<String> natures) {
|
||||
for (String nature : natures) {
|
||||
if (NatureHelper.isDimensionValueClassId(nature)) {
|
||||
if (NatureHelper.isDimensionValueModelId(nature)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectmodelId) {
|
||||
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId) {
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
@@ -43,22 +43,18 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
|
||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectmodelId:{}", terms,
|
||||
regOffsetToLength, offsetList, detectmodelId);
|
||||
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelId:{}", terms,
|
||||
regOffsetToLength, offsetList, detectModelId);
|
||||
|
||||
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectmodelId);
|
||||
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectModelId);
|
||||
Map<MatchText, List<MapResult>> result = new HashMap<>();
|
||||
|
||||
MatchText matchText = MatchText.builder()
|
||||
.regText(text)
|
||||
.detectSegment(text)
|
||||
.build();
|
||||
result.put(matchText, detects);
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
|
||||
Long detectmodelId) {
|
||||
Long detectModelId) {
|
||||
List<MapResult> results = Lists.newArrayList();
|
||||
|
||||
for (Integer index = 0; index <= text.length() - 1; ) {
|
||||
@@ -69,7 +65,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
int offset = mapperHelper.getStepOffset(offsetList, index);
|
||||
i = mapperHelper.getStepIndex(regOffsetToLength, i);
|
||||
if (i <= text.length()) {
|
||||
List<MapResult> mapResults = detectByStep(text, detectmodelId, index, i, offset);
|
||||
List<MapResult> mapResults = detectByStep(text, detectModelId, index, i, offset);
|
||||
selectMapResultInOneRound(mapResultRowSet, mapResults);
|
||||
}
|
||||
}
|
||||
@@ -106,15 +102,15 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
|
||||
private List<MapResult> detectByStep(String text, Long detectmodelId, Integer index, Integer i, int offset) {
|
||||
private List<MapResult> detectByStep(String text, Long detectModelId, Integer index, Integer i, int offset) {
|
||||
String detectSegment = text.substring(index, i);
|
||||
Integer oneDetectionSize = mapperHelper.getOneDetectionSize();
|
||||
|
||||
// step1. pre search
|
||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment,
|
||||
mapperHelper.getOneDetectionMaxSize())
|
||||
Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize();
|
||||
LinkedHashSet<MapResult> mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionSize)
|
||||
LinkedHashSet<MapResult> suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
mapResults.addAll(suffixMapResults);
|
||||
@@ -126,11 +122,11 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step4. filter by classId
|
||||
if (Objects.nonNull(detectmodelId) && detectmodelId > 0) {
|
||||
log.debug("detectmodelId:{}, before parseResults:{}", mapResults);
|
||||
if (Objects.nonNull(detectModelId) && detectModelId > 0) {
|
||||
log.debug("detectModelId:{}, before parseResults:{}", mapResults);
|
||||
mapResults = mapResults.stream().map(entry -> {
|
||||
List<String> natures = entry.getNatures().stream().filter(
|
||||
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectmodelId) || (nature.startsWith(
|
||||
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectModelId) || (nature.startsWith(
|
||||
DictWordType.NATURE_SPILT))
|
||||
).collect(Collectors.toList());
|
||||
entry.setNatures(natures);
|
||||
@@ -145,8 +141,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
log.debug("metricDimensionThreshold:{},dimensionValueThreshold:{},after isSimilarity parseResults:{}",
|
||||
mapResults);
|
||||
log.debug("after isSimilarity parseResults:{}", mapResults);
|
||||
|
||||
mapResults = mapResults.stream().map(parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
@@ -165,7 +160,7 @@ public class QueryMatchStrategy implements MatchStrategy {
|
||||
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
|
||||
return dimensionMetrics;
|
||||
} else {
|
||||
return mapResults.stream().limit(oneDetectionSize).collect(Collectors.toList());
|
||||
return mapResults.stream().limit(mapperHelper.getOneDetectionSize()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
@@ -21,6 +21,9 @@ public class SatisfactionChecker {
|
||||
// check all the parse info in candidate
|
||||
public static boolean check(QueryContext queryContext) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
if (query.getQueryMode().equals(DSLQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2,14 +2,7 @@ package com.tencent.supersonic.chat.parser.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
@@ -17,17 +10,14 @@ import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -53,48 +43,32 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return;
|
||||
}
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
if (plugin == null || DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, List<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
List<Long> modelList = pair.getRight();
|
||||
Set<Long> modelList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
return;
|
||||
}
|
||||
modelList = distinctModelList(plugin, queryContext.getMapInfo(), modelList);
|
||||
for (Long modelId : modelList) {
|
||||
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext,
|
||||
queryContext.getMapInfo().getMatchedElements(modelId));
|
||||
if (plugin.isContainsAllModel()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public List<Long> distinctModelList(Plugin plugin, SchemaMapInfo schemaMapInfo, List<Long> modelList) {
|
||||
if (!plugin.isContainsAllModel()) {
|
||||
return modelList;
|
||||
}
|
||||
boolean noElementMatch = true;
|
||||
for (Long model : modelList) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(model);
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
noElementMatch = false;
|
||||
}
|
||||
}
|
||||
if (noElementMatch) {
|
||||
return modelList.subList(0, 1);
|
||||
}
|
||||
return modelList;
|
||||
}
|
||||
|
||||
private void buildQuery(Plugin plugin, double distance, Long modelId,
|
||||
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
|
||||
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
|
||||
@@ -126,6 +100,8 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
pluginParseResult.setRequest(queryReq);
|
||||
pluginParseResult.setDistance(distance);
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
semanticParseInfo.setProperties(properties);
|
||||
semanticParseInfo.setScore(distance);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
@@ -176,4 +152,8 @@ public class EmbeddingBasedParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -14,23 +14,16 @@ import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -62,6 +55,10 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
return;
|
||||
}
|
||||
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx);
|
||||
if (CollectionUtils.isEmpty(functionDOList)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return;
|
||||
}
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryCtx.getRequest().getQueryText())
|
||||
.pluginConfigs(functionDOList).build();
|
||||
@@ -85,7 +82,7 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
|
||||
ModelResolver ModelResolver = ComponentFactory.getModelResolver();
|
||||
log.info("plugin ModelList:{}", plugin.getModelList());
|
||||
Pair<Boolean, List<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
|
||||
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
|
||||
log.info("FunctionBasedParser modelId:{}", modelId);
|
||||
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
|
||||
@@ -102,6 +99,8 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
functionCallParseResult.setRequest(queryCtx.getRequest());
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, functionCallParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
@@ -112,17 +111,6 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
|
||||
private Set<Long> getMatchModels(QueryContext queryCtx) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
Long modelId = queryCtx.getRequest().getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
result.add(modelId);
|
||||
return result;
|
||||
}
|
||||
return queryCtx.getMapInfo().getMatchedModels();
|
||||
}
|
||||
|
||||
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
|
||||
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) {
|
||||
return true;
|
||||
@@ -140,7 +128,7 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
|
||||
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = PluginManager.getPlugins();
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
return false;
|
||||
@@ -153,12 +141,12 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
if (StringUtils.isBlank(pluginParseConfig.getName())) {
|
||||
return false;
|
||||
}
|
||||
Pair<Boolean, List<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
|
||||
Pair<Boolean, Set<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
|
||||
if (!pluginResolverResult.getLeft()) {
|
||||
return false;
|
||||
} else {
|
||||
List<Long> resolveModel = pluginResolverResult.getRight();
|
||||
Set<Long> resolveModel = pluginResolverResult.getRight();
|
||||
if (modelId != null && modelId > 0) {
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return true;
|
||||
@@ -172,20 +160,6 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
return functionDOList;
|
||||
}
|
||||
|
||||
private List<String> getFunctionNames(Set<Long> matchedModels) {
|
||||
List<Plugin> plugins = PluginManager.getPlugins();
|
||||
Set<String> functionNames = plugins.stream()
|
||||
.filter(entry -> {
|
||||
if (!CollectionUtils.isEmpty(entry.getModelList()) && !CollectionUtils.isEmpty(matchedModels)) {
|
||||
return entry.getModelList().stream().anyMatch(matchedModels::contains);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
).map(Plugin::getName).collect(Collectors.toSet());
|
||||
functionNames.add(DSLQuery.QUERY_MODE);
|
||||
return new ArrayList<>(functionNames);
|
||||
}
|
||||
|
||||
public FunctionResp requestFunction(String url, FunctionReq functionReq) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
long startTime = System.currentTimeMillis();
|
||||
@@ -205,4 +179,8 @@ public class FunctionBasedParser implements SemanticParser {
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,21 +1,12 @@
|
||||
package com.tencent.supersonic.chat.parser.function;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@@ -55,7 +46,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
* @return false will use context Model, true will use other Model , maybe include context Model
|
||||
*/
|
||||
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap,
|
||||
ChatContext chatCtx, QueryReq searchCtx, Long modelId, List<Long> restrictiveModels) {
|
||||
ChatContext chatCtx, QueryReq searchCtx, Long modelId, Set<Long> restrictiveModels) {
|
||||
if (!Objects.nonNull(modelId) || modelId <= 0) {
|
||||
return true;
|
||||
}
|
||||
@@ -80,8 +71,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
}
|
||||
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord()
|
||||
.equalsIgnoreCase(searchCtx.getQueryText())) {
|
||||
if (semanticParseInfo.getDateInfo().getDetectWord().equalsIgnoreCase(searchCtx.getQueryText())) {
|
||||
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
|
||||
semanticParseInfo.getDateInfo());
|
||||
return false;
|
||||
@@ -131,7 +121,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
}
|
||||
|
||||
|
||||
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels) {
|
||||
public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
Long modelId = queryContext.getRequest().getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isNotEmpty(restrictiveModels) && restrictiveModels.contains(modelId)) {
|
||||
@@ -151,17 +141,16 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
for (Long matchedModel : matchedModels) {
|
||||
ModelQueryModes.put(matchedModel, null);
|
||||
}
|
||||
if (ModelQueryModes.size() == 1) {
|
||||
if(ModelQueryModes.size()==1){
|
||||
return ModelQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return resolve(ModelQueryModes, queryContext, chatCtx,
|
||||
queryContext.getMapInfo(), restrictiveModels);
|
||||
queryContext.getMapInfo(),restrictiveModels);
|
||||
}
|
||||
|
||||
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext,
|
||||
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
|
||||
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,
|
||||
restrictiveModels);
|
||||
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
||||
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveModels);
|
||||
if (selectModel > 0) {
|
||||
log.info("selectModel {} ", selectModel);
|
||||
return selectModel;
|
||||
@@ -172,7 +161,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext,
|
||||
ChatContext chatCtx,
|
||||
SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
|
||||
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
|
||||
// if QueryContext has modelId and in ModelQueryModes
|
||||
if (ModelQueryModes.containsKey(queryContext.getModelId())) {
|
||||
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
|
||||
@@ -181,7 +170,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
// if ChatContext has modelId and in ModelQueryModes
|
||||
if (chatCtx.getParseInfo().getModelId() > 0) {
|
||||
Long modelId = chatCtx.getParseInfo().getModelId();
|
||||
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
|
||||
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId,restrictiveModels)) {
|
||||
log.info("selectModel from ChatContext [{}]", modelId);
|
||||
return modelId;
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@ package com.tencent.supersonic.chat.parser.function;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
public interface ModelResolver {
|
||||
|
||||
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels);
|
||||
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMResp;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DSLParseResult extends PluginParseResult {
|
||||
|
||||
private LLMResp llmResp;
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
|
||||
public class DSLDateHelper {
|
||||
|
||||
public static String getCurrentDate(Long modelId) {
|
||||
return DateUtils.getBeforeDate(4);
|
||||
// ChatConfigFilter filter = new ChatConfigFilter();
|
||||
// filter.setModelId(modelId);
|
||||
//
|
||||
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
||||
// if (CollectionUtils.isEmpty(configResps)) {
|
||||
// return
|
||||
// }
|
||||
// ChatConfigResp chatConfigResp = configResps.get(0);
|
||||
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMResp;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DSLParseResult {
|
||||
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private DslTool dslTool;
|
||||
}
|
||||
@@ -1,5 +1,10 @@
|
||||
package com.tencent.supersonic.chat.parser.llm;
|
||||
package com.tencent.supersonic.chat.parser.llm.dsl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
@@ -11,26 +16,26 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.config.LLMConfig;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.parser.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLBuilder;
|
||||
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.dsl.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.dsl.optimizer.BaseDSLOptimizer;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -56,30 +61,30 @@ public class LLMDSLParser implements SemanticParser {
|
||||
queryCtx.getRequest().getQueryText());
|
||||
return;
|
||||
}
|
||||
List<Plugin> dslPlugins = PluginManager.getPlugins().stream()
|
||||
.filter(plugin -> DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isEmpty(dslPlugins)) {
|
||||
return;
|
||||
}
|
||||
Plugin plugin = dslPlugins.get(0);
|
||||
List<Long> dslModels = plugin.getModelList();
|
||||
|
||||
List<DslTool> dslTools = getDslTools(queryCtx.getRequest().getAgentId());
|
||||
Set<Long> distinctModelIds = dslTools.stream().map(DslTool::getModelIds)
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
try {
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, dslModels);
|
||||
log.info("resolve modelId:{},dslModels:{}", modelId, dslModels);
|
||||
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||
log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds);
|
||||
|
||||
if (Objects.isNull(modelId)) {
|
||||
if (Objects.isNull(modelId) || modelId <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
Optional<DslTool> dslToolOptional = dslTools.stream().filter(tool ->
|
||||
tool.getModelIds().contains(modelId)).findFirst();
|
||||
if (!dslToolOptional.isPresent()) {
|
||||
log.info("no dsl tool in this agent, skip dsl parser");
|
||||
return;
|
||||
}
|
||||
DslTool dslTool = dslToolOptional.get();
|
||||
LLMResp llmResp = requestLLM(queryCtx, modelId);
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE);
|
||||
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
@@ -89,10 +94,12 @@ public class LLMDSLParser implements SemanticParser {
|
||||
DSLParseResult dslParseResult = new DSLParseResult();
|
||||
dslParseResult.setRequest(queryCtx.getRequest());
|
||||
dslParseResult.setLlmResp(llmResp);
|
||||
dslParseResult.setPlugin(plugin);
|
||||
dslParseResult.setDslTool(dslToolOptional.get());
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, dslParseResult);
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", dslTool.getName());
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
@@ -126,13 +133,13 @@ public class LLMDSLParser implements SemanticParser {
|
||||
llmSchema.setModelName(modelIdToName.get(modelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(modelId));
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
|
||||
fieldNameList.add(DSLBuilder.DATA_Field);
|
||||
fieldNameList.add(BaseDSLOptimizer.DATE_FIELD);
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
llmReq.setSchema(llmSchema);
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
|
||||
llmReq.setLinking(linking);
|
||||
String currentDate = getCurrentDate(modelId);
|
||||
String currentDate = DSLDateHelper.getCurrentDate(modelId);
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
|
||||
@@ -156,21 +163,6 @@ public class LLMDSLParser implements SemanticParser {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
private String getCurrentDate(Long modelId) {
|
||||
return DateUtils.getBeforeDate(4);
|
||||
// ChatConfigFilter filter = new ChatConfigFilter();
|
||||
// filter.setModelId(modelId);
|
||||
//
|
||||
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
|
||||
// if (CollectionUtils.isEmpty(configResps)) {
|
||||
// return
|
||||
// }
|
||||
// ChatConfigResp chatConfigResp = configResps.get(0);
|
||||
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
|
||||
|
||||
}
|
||||
|
||||
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||
|
||||
@@ -228,4 +220,18 @@ public class LLMDSLParser implements SemanticParser {
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private List<DslTool> getDslTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.DSL);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.interpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.metricInterpret.MetricInterpretQuery;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class MetricInterpretParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (SatisfactionChecker.check(queryContext)) {
|
||||
log.info("skip MetricInterpretParser");
|
||||
return;
|
||||
}
|
||||
Map<Long, MetricInterpretTool> metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId());
|
||||
log.info("metric interpret tool : {}", metricInterpretToolMap);
|
||||
if (CollectionUtils.isEmpty(metricInterpretToolMap)) {
|
||||
return;
|
||||
}
|
||||
Map<Long, List<SchemaElementMatch>> elementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
for (Long modelId : elementMatches.keySet()) {
|
||||
MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId);
|
||||
if (metricInterpretTool == null) {
|
||||
continue;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(elementMatches.get(modelId))) {
|
||||
continue;
|
||||
}
|
||||
List<MetricOption> metricOptions = metricInterpretTool.getMetricOptions();
|
||||
if (!CollectionUtils.isEmpty(metricOptions)) {
|
||||
List<Long> metricIds = metricOptions.stream().map(MetricOption::getMetricId).collect(Collectors.toList());
|
||||
buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), metricInterpretTool.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void buildQuery(Long modelId, QueryContext queryContext,
|
||||
List<Long> metricIds, List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE);
|
||||
Set<SchemaElement> metrics = getMetrics(metricIds, modelId);
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(),
|
||||
metrics, schemaElementMatches, toolName);
|
||||
semanticParseInfo.setQueryMode(metricInterpretQuery.getQueryMode());
|
||||
semanticParseInfo.getProperties().put("queryText", queryContext.getRequest().getQueryText());
|
||||
metricInterpretQuery.setParseInfo(semanticParseInfo);
|
||||
queryContext.getCandidateQueries().add(metricInterpretQuery);
|
||||
}
|
||||
|
||||
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
|
||||
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
ModelSchema modelSchema = semanticLayer.getModelSchema(modelId, true);
|
||||
Set<SchemaElement> metrics = modelSchema.getMetrics();
|
||||
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
private Map<Long, MetricInterpretTool> getMetricInterpretTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<String> tools= agent.getTools(AgentToolType.INTERPRET);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
List<MetricInterpretTool> metricInterpretTools = tools.stream().map(tool ->
|
||||
JSONObject.parseObject(tool, MetricInterpretTool.class))
|
||||
.filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions()))
|
||||
.collect(Collectors.toList());
|
||||
Map<Long, MetricInterpretTool> metricInterpretToolMap = new HashMap<>();
|
||||
for (MetricInterpretTool metricInterpretTool : metricInterpretTools) {
|
||||
metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(),
|
||||
metricInterpretTool);
|
||||
}
|
||||
return metricInterpretToolMap;
|
||||
}
|
||||
|
||||
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
|
||||
List<SchemaElementMatch> schemaElementMatches, String toolName) {
|
||||
SchemaElement Model = new SchemaElement();
|
||||
Model.setModel(modelId);
|
||||
Model.setId(modelId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setMetrics(metrics);
|
||||
SchemaElement dimension = new SchemaElement();
|
||||
dimension.setBizName(TimeDimensionEnum.DAY.getName());
|
||||
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(Model);
|
||||
semanticParseInfo.setScore(queryReq.getQueryText().length());
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||
dateConf.setUnit(15);
|
||||
semanticParseInfo.setDateInfo(dateConf);
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put("type", "internal");
|
||||
properties.put("name", toolName);
|
||||
semanticParseInfo.setProperties(properties);
|
||||
fillSemanticParseInfo(semanticParseInfo);
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
|
||||
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
|
||||
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.forEach(schemaElementMatch -> {
|
||||
QueryFilter queryFilter = new QueryFilter();
|
||||
queryFilter.setValue(schemaElementMatch.getWord());
|
||||
queryFilter.setElementID(schemaElementMatch.getElement().getId());
|
||||
queryFilter.setName(schemaElementMatch.getElement().getName());
|
||||
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
|
||||
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.interpret;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class MetricOption {
|
||||
|
||||
private Long metricId;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.parser.llm;
|
||||
package com.tencent.supersonic.chat.parser.llm.time;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.utils.ChatGptHelper;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.util.ChatGptHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -17,20 +17,20 @@ public class LLMTimeEnhancementParse implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
|
||||
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
|
||||
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
|
||||
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
|
||||
try {
|
||||
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
|
||||
if (!queryContext.getCandidateQueries().isEmpty()) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
DateConf dateInfo = query.getParseInfo().getDateInfo();
|
||||
JSONObject jsonObject = JSON.parseObject(inferredTime);
|
||||
if (jsonObject.containsKey("date")) {
|
||||
if (jsonObject.containsKey("date")){
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(jsonObject.getString("date"));
|
||||
dateInfo.setEndDate(jsonObject.getString("date"));
|
||||
query.getParseInfo().setDateInfo(dateInfo);
|
||||
} else if (jsonObject.containsKey("start")) {
|
||||
}else if (jsonObject.containsKey("start")){
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(jsonObject.getString("start"));
|
||||
dateInfo.setEndDate(jsonObject.getString("end"));
|
||||
@@ -38,12 +38,11 @@ public class LLMTimeEnhancementParse implements SemanticParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception exception) {
|
||||
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
|
||||
(Object) exception.getStackTrace());
|
||||
}catch (Exception exception){
|
||||
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
|
||||
}
|
||||
|
||||
log.info("after queryContext:{},chatContext:{}", queryContext, chatContext);
|
||||
log.info("{} after queryContext:{},chatContext:{}",LLMTimeEnhancementParse.class.getSimpleName(),queryContext,chatContext);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.RuleQueryTool;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class AgentCheckParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
List<SemanticQuery> queries = queryContext.getCandidateQueries();
|
||||
agentCanSupport(queryContext.getRequest().getAgentId(), queries);
|
||||
}
|
||||
|
||||
private void agentCanSupport(Integer agentId, List<SemanticQuery> queries) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return;
|
||||
}
|
||||
List<String> queryModes = getRuleTools(agentId).stream().map(RuleQueryTool::getQueryModes)
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
queries.clear();
|
||||
return;
|
||||
}
|
||||
log.info("queries resolved:{} {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
queries.removeIf(query ->
|
||||
!queryModes.contains(query.getQueryMode()));
|
||||
log.info("rule queries witch can be supported by agent :{} {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static List<RuleQueryTool> getRuleTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.RULE);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,12 +1,9 @@
|
||||
package com.tencent.supersonic.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
@@ -15,12 +12,10 @@ public class QueryModeParser implements SemanticParser {
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
|
||||
// iterate all schemaElementMatches to resolve semantic query
|
||||
for (Long modelId : mapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(modelId, queryContext, chatContext);
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
package com.tencent.supersonic.chat.persistence.dataobject;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
public class AgentDO {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String name;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String description;
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
*/
|
||||
private Integer status;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String examples;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String config;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String createdBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date createdAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String updatedBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date updatedAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Integer enableSearch;
|
||||
|
||||
/**
|
||||
*
|
||||
* @return id
|
||||
*/
|
||||
public Integer getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param id
|
||||
*/
|
||||
public void setId(Integer id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return name
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param name
|
||||
*/
|
||||
public void setName(String name) {
|
||||
this.name = name == null ? null : name.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return description
|
||||
*/
|
||||
public String getDescription() {
|
||||
return description;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param description
|
||||
*/
|
||||
public void setDescription(String description) {
|
||||
this.description = description == null ? null : description.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
* @return status 0 offline, 1 online
|
||||
*/
|
||||
public Integer getStatus() {
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* 0 offline, 1 online
|
||||
* @param status 0 offline, 1 online
|
||||
*/
|
||||
public void setStatus(Integer status) {
|
||||
this.status = status;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return examples
|
||||
*/
|
||||
public String getExamples() {
|
||||
return examples;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param examples
|
||||
*/
|
||||
public void setExamples(String examples) {
|
||||
this.examples = examples == null ? null : examples.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return config
|
||||
*/
|
||||
public String getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param config
|
||||
*/
|
||||
public void setConfig(String config) {
|
||||
this.config = config == null ? null : config.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return created_by
|
||||
*/
|
||||
public String getCreatedBy() {
|
||||
return createdBy;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param createdBy
|
||||
*/
|
||||
public void setCreatedBy(String createdBy) {
|
||||
this.createdBy = createdBy == null ? null : createdBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return created_at
|
||||
*/
|
||||
public Date getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param createdAt
|
||||
*/
|
||||
public void setCreatedAt(Date createdAt) {
|
||||
this.createdAt = createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return updated_by
|
||||
*/
|
||||
public String getUpdatedBy() {
|
||||
return updatedBy;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param updatedBy
|
||||
*/
|
||||
public void setUpdatedBy(String updatedBy) {
|
||||
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return updated_at
|
||||
*/
|
||||
public Date getUpdatedAt() {
|
||||
return updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param updatedAt
|
||||
*/
|
||||
public void setUpdatedAt(Date updatedAt) {
|
||||
this.updatedAt = updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return enable_search
|
||||
*/
|
||||
public Integer getEnableSearch() {
|
||||
return enableSearch;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param enableSearch
|
||||
*/
|
||||
public void setEnableSearch(Integer enableSearch) {
|
||||
this.enableSearch = enableSearch;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.chat.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface AgentDOMapper {
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
long countByExample(AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int deleteByPrimaryKey(Integer id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insert(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insertSelective(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<AgentDO> selectByExample(AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
AgentDO selectByPrimaryKey(Integer id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeySelective(AgentDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKey(AgentDO record);
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentRepository {
|
||||
|
||||
List<AgentDO> getAgents();
|
||||
|
||||
void createAgent(AgentDO agentDO);
|
||||
|
||||
void updateAgent(AgentDO agentDO);
|
||||
|
||||
AgentDO getAgent(Integer id);
|
||||
|
||||
void deleteAgent(Integer id);
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.chat.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import java.util.List;
|
||||
|
||||
@Repository
|
||||
public class AgentRepositoryImpl implements AgentRepository {
|
||||
|
||||
private AgentDOMapper agentDOMapper;
|
||||
|
||||
public AgentRepositoryImpl(AgentDOMapper agentDOMapper) {
|
||||
this.agentDOMapper = agentDOMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<AgentDO> getAgents() {
|
||||
return agentDOMapper.selectByExample(new AgentDOExample());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createAgent(AgentDO agentDO) {
|
||||
agentDOMapper.insert(agentDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAgent(AgentDO agentDO) {
|
||||
agentDOMapper.updateByPrimaryKey(agentDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AgentDO getAgent(Integer id) {
|
||||
return agentDOMapper.selectByPrimaryKey(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAgent(Integer id) {
|
||||
agentDOMapper.deleteByPrimaryKey(id);
|
||||
}
|
||||
}
|
||||
@@ -7,12 +7,14 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@Slf4j
|
||||
public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
|
||||
@Autowired(required = false)
|
||||
@@ -50,8 +52,8 @@ public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
chatContext.setUser(contextDO.getUser());
|
||||
chatContext.setQueryText(contextDO.getQueryText());
|
||||
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
|
||||
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
|
||||
SemanticParseInfo.class);
|
||||
log.info("--->: {}",contextDO.getSemanticParse());
|
||||
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class);
|
||||
chatContext.setParseInfo(semanticParseInfo);
|
||||
}
|
||||
return chatContext;
|
||||
|
||||
@@ -3,11 +3,11 @@ package com.tencent.supersonic.chat.plugin;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.agent.tool.DslTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.PluginTool;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
|
||||
@@ -16,30 +16,22 @@ import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
@@ -59,12 +51,40 @@ public class PluginManager {
|
||||
this.restTemplate = restTemplate;
|
||||
}
|
||||
|
||||
public static List<Plugin> getPlugins() {
|
||||
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
List<Plugin> pluginList = pluginService.getPluginList().stream().filter(plugin ->
|
||||
CollectionUtils.isNotEmpty(plugin.getModelList())).collect(Collectors.toList());
|
||||
pluginList.addAll(internalPluginMap.values());
|
||||
return new ArrayList<>(pluginList);
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
if (agentId == null) {
|
||||
return plugins;
|
||||
}
|
||||
Agent agent = ContextUtils.getBean(AgentService.class).getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return plugins;
|
||||
}
|
||||
List<Long> pluginIds = getPluginTools(agentId).stream().map(PluginTool::getPlugins)
|
||||
.flatMap(Collection::stream).collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(pluginIds)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId()))
|
||||
.collect(Collectors.toList());
|
||||
log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
|
||||
plugins.stream().map(Plugin::getName).collect(Collectors.toList()));
|
||||
return plugins;
|
||||
}
|
||||
|
||||
private static List<PluginTool> getPluginTools(Integer agentId) {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.PLUGIN);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@EventListener
|
||||
@@ -201,17 +221,17 @@ public class PluginManager {
|
||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||
}
|
||||
|
||||
public static Pair<Boolean, List<Long>> resolve(Plugin plugin, QueryContext queryContext) {
|
||||
public static Pair<Boolean, Set<Long>> resolve(Plugin plugin, QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
|
||||
return Pair.of(false, Lists.newArrayList());
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
List<ParamOption> paramOptions = getSemanticOption(plugin);
|
||||
if (CollectionUtils.isEmpty(paramOptions)) {
|
||||
return Pair.of(true, new ArrayList<>(pluginMatchedModel));
|
||||
return Pair.of(true, Sets.newHashSet());
|
||||
}
|
||||
List<Long> matchedModel = Lists.newArrayList();
|
||||
Set<Long> matchedModel = Sets.newHashSet();
|
||||
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream().
|
||||
collect(Collectors.groupingBy(ParamOption::getModelId));
|
||||
for (Long modelId : paramOptionMap.keySet()) {
|
||||
@@ -237,7 +257,7 @@ public class PluginManager {
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(matchedModel)) {
|
||||
return Pair.of(false, Lists.newArrayList());
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
return Pair.of(true, matchedModel);
|
||||
}
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
package com.tencent.supersonic.chat.query.ContentInterpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class ContentInterpretQuery extends PluginSemanticQuery {
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return "CONTENT_INTERPRET";
|
||||
}
|
||||
|
||||
public ContentInterpretQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) throws SqlParseException {
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = queryMetric(user);
|
||||
String text = generateDataText(queryResultWithSchemaResp);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT))
|
||||
, PluginParseResult.class);
|
||||
String answer = fetchInterpret(pluginParseResult.getRequest().getQueryText(), text);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("answer", answer);
|
||||
List<Map<String, Object>> resultList = Lists.newArrayList();
|
||||
resultList.add(result);
|
||||
queryResultWithSchemaResp.setResultList(resultList);
|
||||
queryResultWithSchemaResp.setColumns(queryColumns);
|
||||
queryResult.setResponse(queryResultWithSchemaResp);
|
||||
queryResult.setQueryMode(getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
private QueryResultWithSchemaResp queryMetric(User user) {
|
||||
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setModelId(parseInfo.getModelId());
|
||||
queryStructReq.setGroups(Lists.newArrayList(TimeDimensionEnum.DAY.getName()));
|
||||
ModelSchema modelSchema = semanticLayer.getModelSchema(parseInfo.getModelId(), true);
|
||||
queryStructReq.setAggregators(buildAggregator(modelSchema));
|
||||
List<Filter> filterList = Lists.newArrayList();
|
||||
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
|
||||
Filter filter = new Filter();
|
||||
BeanUtils.copyProperties(queryFilter, filter);
|
||||
filterList.add(filter);
|
||||
}
|
||||
queryStructReq.setDimensionFilters(filterList);
|
||||
DateConf dateConf = new DateConf();
|
||||
dateConf.setDateMode(DateConf.DateMode.RECENT);
|
||||
dateConf.setUnit(7);
|
||||
queryStructReq.setDateInfo(dateConf);
|
||||
return semanticLayer.queryByStruct(queryStructReq, user);
|
||||
}
|
||||
|
||||
private List<Aggregator> buildAggregator(ModelSchema modelSchema) {
|
||||
List<Aggregator> aggregators = Lists.newArrayList();
|
||||
Set<SchemaElement> metrics = modelSchema.getMetrics();
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return aggregators;
|
||||
}
|
||||
for (SchemaElement schemaElement : metrics) {
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.setColumn(schemaElement.getBizName());
|
||||
aggregator.setFunc(AggOperatorEnum.SUM);
|
||||
aggregator.setNameCh(schemaElement.getName());
|
||||
aggregators.add(aggregator);
|
||||
}
|
||||
return aggregators;
|
||||
}
|
||||
|
||||
|
||||
public String generateDataText(QueryResultWithSchemaResp queryResultWithSchemaResp) {
|
||||
Map<String, String> map = queryResultWithSchemaResp.getColumns().stream()
|
||||
.collect(Collectors.toMap(QueryColumn::getNameEn, QueryColumn::getName));
|
||||
StringBuilder stringBuilder = new StringBuilder();
|
||||
for (Map<String, Object> valueMap : queryResultWithSchemaResp.getResultList()) {
|
||||
for (String key : valueMap.keySet()) {
|
||||
String name = "";
|
||||
if (TimeDimensionEnum.getNameList().contains(key)) {
|
||||
name = "日期";
|
||||
} else {
|
||||
name = map.get(key);
|
||||
}
|
||||
String value = String.valueOf(valueMap.get(key));
|
||||
stringBuilder.append(name).append(":").append(value).append(" ");
|
||||
}
|
||||
}
|
||||
return stringBuilder.toString();
|
||||
}
|
||||
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
|
||||
lLmAnswerReq.setQueryText(queryText);
|
||||
lLmAnswerReq.setPluginOutput(dataText);
|
||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
||||
JSONObject.toJSONString(lLmAnswerReq));
|
||||
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
|
||||
if (lLmAnswerResp != null) {
|
||||
return lLmAnswerResp.getAssistant_message();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.query;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
@@ -18,7 +19,7 @@ public class HeuristicQuerySelector implements QuerySelector {
|
||||
private static final double CANDIDATE_THRESHOLD = 0.2;
|
||||
|
||||
@Override
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries) {
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@@ -8,5 +10,5 @@ import java.util.List;
|
||||
**/
|
||||
public interface QuerySelector {
|
||||
|
||||
List<SemanticQuery> select(List<SemanticQuery> candidateQueries);
|
||||
List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq);
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
package com.tencent.supersonic.chat.query.dsl;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class DSLBuilder {
|
||||
|
||||
public static final String DATA_Field = "数据日期";
|
||||
public static final String TABLE_PREFIX = "t_";
|
||||
|
||||
public String build(SemanticParseInfo parseInfo, QueryFilters queryFilters, LLMResp llmResp, Long modelId)
|
||||
throws Exception {
|
||||
|
||||
String sqlOutput = llmResp.getSqlOutput();
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||
|
||||
Map<String, String> fieldToBizName = getMapInfo(modelId, dbAllFields);
|
||||
fieldToBizName.put(DATA_Field, TimeDimensionEnum.DAY.getName());
|
||||
|
||||
sqlOutput = CCJSqlParserUtils.replaceFields(sqlOutput, fieldToBizName);
|
||||
|
||||
sqlOutput = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
|
||||
|
||||
String queryFilter = getQueryFilter(queryFilters);
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to sql :{}", queryFilter);
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
CCJSqlParserUtils.addWhere(sqlOutput, expression);
|
||||
}
|
||||
|
||||
log.info("build sqlOutput:{}", sqlOutput);
|
||||
return sqlOutput;
|
||||
}
|
||||
|
||||
protected Map<String, String> getMapInfo(Long modelId, List<SchemaElement> metrics) {
|
||||
return metrics.stream().filter(entry -> entry.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return "";
|
||||
}
|
||||
List<QueryFilter> filters = queryFilters.getFilters();
|
||||
|
||||
return filters.stream()
|
||||
.map(filter -> {
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||
return bizNameWrap + operatorWrap + valueWrap;
|
||||
})
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
}
|
||||
}
|
||||
@@ -2,14 +2,14 @@ package com.tencent.supersonic.chat.query.dsl;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.parser.llm.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
@@ -32,7 +32,6 @@ import org.springframework.stereotype.Component;
|
||||
public class DSLQuery extends PluginSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "DSL";
|
||||
private DSLBuilder dslBuilder = new DSLBuilder();
|
||||
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
|
||||
public DSLQuery() {
|
||||
@@ -51,12 +50,26 @@ public class DSLQuery extends PluginSemanticQuery {
|
||||
LLMResp llmResp = dslParseResult.getLlmResp();
|
||||
QueryReq queryReq = dslParseResult.getRequest();
|
||||
|
||||
Long modelId = parseInfo.getModelId();
|
||||
String querySql = convertToSql(queryReq.getQueryFilters(), llmResp, parseInfo, modelId);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.queryFilters(queryReq.getQueryFilters())
|
||||
.sql(llmResp.getSqlOutput())
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
List<DSLOptimizer> DSLCorrections = ComponentFactory.getSqlCorrections();
|
||||
|
||||
DSLCorrections.forEach(DSLCorrection -> {
|
||||
try {
|
||||
DSLCorrection.rewriter(correctionInfo);
|
||||
log.info("sqlCorrection:{} sql:{}", DSLCorrection.getClass().getSimpleName(), correctionInfo.getSql());
|
||||
} catch (Exception e) {
|
||||
log.error("sqlCorrection:{} execute error,correctionInfo:{}", DSLCorrection, correctionInfo, e);
|
||||
}
|
||||
});
|
||||
String querySql = correctionInfo.getSql();
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
|
||||
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, modelId);
|
||||
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, parseInfo.getModelId());
|
||||
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
|
||||
|
||||
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
@@ -80,17 +93,4 @@ public class DSLQuery extends PluginSemanticQuery {
|
||||
parseInfo.setProperties(null);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo,
|
||||
Long modelId) {
|
||||
try {
|
||||
return dslBuilder.build(parseInfo, queryFilters, llmResp, modelId);
|
||||
} catch (Exception e) {
|
||||
log.error("convertToSql error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseDSLOptimizer implements DSLOptimizer {
|
||||
public static final String DATE_FIELD = "数据日期";
|
||||
protected Map<String, String> getFieldToBizName(Long modelId) {
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
|
||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||
dbAllFields.addAll(semanticSchema.getDimensions());
|
||||
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> entry.getModel().equals(modelId))
|
||||
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
|
||||
result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName());
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class DateFieldCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
|
||||
String sql = correctionInfo.getSql();
|
||||
List<String> whereFields = CCJSqlParserUtils.getWhereFields(sql);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(BaseDSLOptimizer.DATE_FIELD)) {
|
||||
String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId());
|
||||
sql = CCJSqlParserUtils.addWhere(sql, BaseDSLOptimizer.DATE_FIELD, currentDate);
|
||||
}
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FieldCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
String replaceFields = CCJSqlParserUtils.replaceFields(correctionInfo.getSql(),
|
||||
getFieldToBizName(correctionInfo.getParseInfo().getModelId()));
|
||||
correctionInfo.setSql(replaceFields);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FunctionCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
String replaceFunction = CCJSqlParserUtils.replaceFunction(correctionInfo.getSql());
|
||||
correctionInfo.setSql(replaceFunction);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterAppend extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException {
|
||||
String queryFilter = getQueryFilter(correctionInfo.getQueryFilters());
|
||||
String sql = correctionInfo.getSql();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to sql :{}", queryFilter);
|
||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||
sql = CCJSqlParserUtils.addWhere(sql, expression);
|
||||
}
|
||||
correctionInfo.setSql(sql);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
private String getQueryFilter(QueryFilters queryFilters) {
|
||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return null;
|
||||
}
|
||||
return queryFilters.getFilters().stream()
|
||||
.map(filter -> {
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||
return bizNameWrap + operatorWrap + valueWrap;
|
||||
})
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class SelectFieldAppendCorrector extends BaseDSLOptimizer {
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
String sql = correctionInfo.getSql();
|
||||
if (CCJSqlParserUtils.hasAggregateFunction(sql)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
Set<String> selectFields = new HashSet<>(CCJSqlParserUtils.getSelectFields(sql));
|
||||
Set<String> whereFields = new HashSet<>(CCJSqlParserUtils.getWhereFields(sql));
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
whereFields.removeAll(selectFields);
|
||||
whereFields.remove(TimeDimensionEnum.DAY.getName());
|
||||
whereFields.remove(TimeDimensionEnum.WEEK.getName());
|
||||
whereFields.remove(TimeDimensionEnum.MONTH.getName());
|
||||
String replaceFields = CCJSqlParserUtils.addFieldsToSelect(sql, new ArrayList<>(whereFields));
|
||||
correctionInfo.setSql(replaceFields);
|
||||
return correctionInfo;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class TableNameCorrector extends BaseDSLOptimizer {
|
||||
|
||||
public static final String TABLE_PREFIX = "t_";
|
||||
|
||||
@Override
|
||||
public CorrectionInfo rewriter(CorrectionInfo correctionInfo) {
|
||||
Long modelId = correctionInfo.getParseInfo().getModelId();
|
||||
String sqlOutput = correctionInfo.getSql();
|
||||
String replaceTable = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
|
||||
correctionInfo.setSql(replaceTable);
|
||||
return correctionInfo;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.ContentInterpret;
|
||||
package com.tencent.supersonic.chat.query.metricInterpret;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.query.ContentInterpret;
|
||||
package com.tencent.supersonic.chat.query.metricInterpret;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,143 @@
|
||||
package com.tencent.supersonic.chat.query.metricInterpret;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricInterpretQuery extends PluginSemanticQuery {
|
||||
|
||||
|
||||
public final static String QUERY_MODE = "METRIC_INTERPRET";
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
public MetricInterpretQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) throws SqlParseException {
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
|
||||
fillAggregator(queryStructReq, parseInfo.getMetrics());
|
||||
queryStructReq.setNativeQuery(true);
|
||||
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = semanticLayer.queryByStruct(queryStructReq, user);
|
||||
String text = generateTableText(queryResultWithSchemaResp);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
Map<String, String> replacedMap = new HashMap<>();
|
||||
String textReplaced = replaceText((String) properties.get("queryText"), parseInfo.getElementMatches(), replacedMap);
|
||||
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果","string","answer"));
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("answer", answer);
|
||||
List<Map<String, Object>> resultList = Lists.newArrayList();
|
||||
resultList.add(result);
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(queryColumns);
|
||||
queryResult.setQueryMode(getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches, Map<String, String> replacedMap) {
|
||||
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
return text;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
|
||||
String detectWord = schemaElementMatch.getDetectWord();
|
||||
if (StringUtils.isBlank(detectWord)) {
|
||||
continue;
|
||||
}
|
||||
text = text.replace(detectWord, "xxx");
|
||||
replacedMap.put("xxx", detectWord);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
private void fillAggregator(QueryStructReq queryStructReq, Set<SchemaElement> schemaElements) {
|
||||
queryStructReq.getAggregators().clear();
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.setColumn(schemaElement.getBizName());
|
||||
aggregator.setFunc(AggOperatorEnum.SUM);
|
||||
aggregator.setNameCh(schemaElement.getName());
|
||||
queryStructReq.getAggregators().add(aggregator);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private String replaceAnswer(String text, Map<String, String> replacedMap) {
|
||||
for (String key : replacedMap.keySet()) {
|
||||
text = text.replaceAll(key, replacedMap.get(key));
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
public static String generateTableText(QueryResultWithSchemaResp result) {
|
||||
StringBuilder tableBuilder = new StringBuilder();
|
||||
for (QueryColumn column : result.getColumns()) {
|
||||
tableBuilder.append(column.getName()).append("\t");
|
||||
}
|
||||
tableBuilder.append("\n");
|
||||
for (Map<String, Object> row : result.getResultList()) {
|
||||
for (QueryColumn column : result.getColumns()) {
|
||||
tableBuilder.append(row.get(column.getNameEn())).append("\t");
|
||||
}
|
||||
tableBuilder.append("\n");
|
||||
}
|
||||
return tableBuilder.toString();
|
||||
}
|
||||
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
|
||||
lLmAnswerReq.setQueryText(queryText);
|
||||
lLmAnswerReq.setPluginOutput(dataText);
|
||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
||||
JSONObject.toJSONString(lLmAnswerReq));
|
||||
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
|
||||
if (lLmAnswerResp != null) {
|
||||
return lLmAnswerResp.getAssistant_message();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -2,15 +2,14 @@ package com.tencent.supersonic.chat.query.plugin.webpage;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
@@ -18,18 +17,18 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
|
||||
import com.tencent.supersonic.chat.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebPageQuery extends PluginSemanticQuery {
|
||||
@@ -107,17 +106,15 @@ public class WebPageQuery extends PluginSemanticQuery {
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.sorted(Comparator.comparingDouble(SchemaElementMatch::getSimilarity))
|
||||
.filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0)
|
||||
.forEach(schemaElementMatch -> {
|
||||
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
|
||||
if (queryFilterValue != null) {
|
||||
if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) {
|
||||
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
schemaElementMatch.getWord());
|
||||
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()), schemaElementMatch.getWord());
|
||||
}
|
||||
} else {
|
||||
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
|
||||
schemaElementMatch.getWord());
|
||||
elementValueMap.computeIfAbsent(String.valueOf(schemaElementMatch.getElement().getId()), k -> schemaElementMatch.getWord());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -208,8 +208,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
QueryResult queryResult = new QueryResult();
|
||||
QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(
|
||||
convertQueryStruct(), user);
|
||||
QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(convertQueryStruct(), user);
|
||||
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -17,8 +17,7 @@ public class EntityFilterQuery extends EntityListQuery {
|
||||
|
||||
public EntityFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0);
|
||||
queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0);
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.tencent.supersonic.chat.query.rule.entity;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Component
|
||||
public class EntityIdQuery extends EntityListQuery {
|
||||
|
||||
public static final String QUERY_MODE = "ENTITY_ID";
|
||||
|
||||
public EntityIdQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package com.tencent.supersonic.chat.rest;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/agent")
|
||||
public class AgentController {
|
||||
|
||||
private AgentService agentService;
|
||||
|
||||
public AgentController(AgentService agentService) {
|
||||
this.agentService = agentService;
|
||||
}
|
||||
|
||||
@PostMapping
|
||||
public boolean createAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
agentService.createAgent(agent, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
@PutMapping
|
||||
public boolean updateAgent(@RequestBody Agent agent,
|
||||
HttpServletRequest httpServletRequest,
|
||||
HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
agentService.updateAgent(agent, user);
|
||||
return true;
|
||||
}
|
||||
|
||||
@DeleteMapping("/{id}")
|
||||
public boolean deleteAgent(@PathVariable("id") Integer id) {
|
||||
agentService.deleteAgent(id);
|
||||
return true;
|
||||
}
|
||||
|
||||
@RequestMapping("/getAgentList")
|
||||
public List<Agent> getAgentList() {
|
||||
return agentService.getAgents();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -32,7 +32,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/conf")
|
||||
@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"})
|
||||
public class ChatConfigController {
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -18,7 +18,7 @@ import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/manage")
|
||||
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
|
||||
public class ChatController {
|
||||
|
||||
private final ChatService chatService;
|
||||
|
||||
@@ -20,7 +20,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
* query controller
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/query")
|
||||
@RequestMapping({"/api/chat/query", "/openapi/chat/query"})
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -19,7 +19,7 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
* recommend controller
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/api/chat/")
|
||||
@RequestMapping({"/api/chat/", "/openapi/chat/"})
|
||||
public class RecommendController {
|
||||
|
||||
@Autowired
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.chat.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentService {
|
||||
|
||||
List<Agent> getAgents();
|
||||
|
||||
void createAgent(Agent agent, User user);
|
||||
|
||||
void updateAgent(Agent agent, User user);
|
||||
|
||||
Agent getAgent(Integer id);
|
||||
|
||||
void deleteAgent(Integer id);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.persistence.repository.AgentRepository;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
public class AgentServiceImpl implements AgentService {
|
||||
|
||||
private AgentRepository agentRepository;
|
||||
|
||||
public AgentServiceImpl(AgentRepository agentRepository) {
|
||||
this.agentRepository = agentRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
return getAgentDOList().stream()
|
||||
.map(this::convert).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createAgent(Agent agent, User user) {
|
||||
agentRepository.createAgent(convert(agent, user));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateAgent(Agent agent, User user) {
|
||||
agentRepository.updateAgent(convert(agent, user));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Agent getAgent(Integer id) {
|
||||
if (id == null) {
|
||||
return null;
|
||||
}
|
||||
return convert(agentRepository.getAgent(id));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteAgent(Integer id) {
|
||||
agentRepository.deleteAgent(id);
|
||||
}
|
||||
|
||||
private List<AgentDO> getAgentDOList() {
|
||||
return agentRepository.getAgents();
|
||||
}
|
||||
|
||||
private Agent convert(AgentDO agentDO){
|
||||
if (agentDO == null ) {
|
||||
return null;
|
||||
}
|
||||
Agent agent = new Agent();
|
||||
BeanUtils.copyProperties(agentDO,agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class));
|
||||
return agent;
|
||||
}
|
||||
|
||||
private AgentDO convert(Agent agent, User user){
|
||||
AgentDO agentDO = new AgentDO();
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setConfig(agent.getAgentConfig());
|
||||
agentDO.setExamples(JSONObject.toJSONString(agent.getExamples()));
|
||||
agentDO.setCreatedAt(new Date());
|
||||
agentDO.setCreatedBy(user.getName());
|
||||
agentDO.setUpdatedAt(new Date());
|
||||
agentDO.setUpdatedBy(user.getName());
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
return agentDO;
|
||||
}
|
||||
}
|
||||
@@ -2,26 +2,25 @@ package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.component.*;
|
||||
import com.tencent.supersonic.chat.api.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -63,12 +62,14 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (queryCtx.getCandidateQueries().size() > 0) {
|
||||
log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
|
||||
Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries());
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
|
||||
log.debug("pick after [{}]", selectedQueries.stream().collect(
|
||||
Collectors.toList()));
|
||||
|
||||
List<SemanticParseInfo> selectedParses = selectedQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
.map(SemanticQuery::getParseInfo)
|
||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||
.collect(Collectors.toList());
|
||||
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
|
||||
@@ -138,7 +139,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
if (queryCtx.getCandidateQueries().size() > 0) {
|
||||
log.info("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
|
||||
Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries());
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq);
|
||||
log.info("pick after [{}]", selectedQueries.stream().collect(
|
||||
Collectors.toList()));
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ public class RecommendServiceImpl implements RecommendService {
|
||||
item.setName(dimSchemaDesc.getName());
|
||||
item.setBizName(dimSchemaDesc.getBizName());
|
||||
item.setId(dimSchemaDesc.getId());
|
||||
item.setAlias(dimSchemaDesc.getAlias());
|
||||
return item;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
@@ -70,6 +71,7 @@ public class RecommendServiceImpl implements RecommendService {
|
||||
item.setName(metricSchemaDesc.getName());
|
||||
item.setBizName(metricSchemaDesc.getBizName());
|
||||
item.setId(metricSchemaDesc.getId());
|
||||
item.setAlias(metricSchemaDesc.getAlias());
|
||||
return item;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
@@ -13,6 +14,7 @@ import com.tencent.supersonic.chat.mapper.MatchText;
|
||||
import com.tencent.supersonic.chat.mapper.ModelInfoStat;
|
||||
import com.tencent.supersonic.chat.mapper.ModelWithSemanticType;
|
||||
import com.tencent.supersonic.chat.mapper.SearchMatchStrategy;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.SearchService;
|
||||
import com.tencent.supersonic.chat.utils.NatureHelper;
|
||||
@@ -53,22 +55,34 @@ public class SearchServiceImpl implements SearchService {
|
||||
private ChatService chatService;
|
||||
@Autowired
|
||||
private SearchMatchStrategy searchMatchStrategy;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(QueryReq queryCtx) {
|
||||
|
||||
// 1. check search enable
|
||||
Integer agentId = queryCtx.getAgentId();
|
||||
if (agentId != null) {
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
}
|
||||
|
||||
String queryText = queryCtx.getQueryText();
|
||||
// 1.get meta info
|
||||
// 2.get meta info
|
||||
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
|
||||
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
|
||||
final Map<Long, String> modelToName = semanticSchemaDb.getModelIdToName();
|
||||
|
||||
// 2.detect by segment
|
||||
// 3.detect by segment
|
||||
List<Term> originals = HanlpHelper.getTerms(queryText);
|
||||
Map<MatchText, List<MapResult>> regTextMap = searchMatchStrategy.match(queryText, originals,
|
||||
queryCtx.getModelId());
|
||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||
|
||||
// 3.get the most matching data
|
||||
// 4.get the most matching data
|
||||
Optional<Entry<MatchText, List<MapResult>>> mostSimilarSearchResult = regTextMap.entrySet()
|
||||
.stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
@@ -77,7 +91,7 @@ public class SearchServiceImpl implements SearchService {
|
||||
? entry1 : entry2);
|
||||
log.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult);
|
||||
|
||||
// 4.optimize the results after the query
|
||||
// 5.optimize the results after the query
|
||||
if (!mostSimilarSearchResult.isPresent()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
@@ -89,11 +103,11 @@ public class SearchServiceImpl implements SearchService {
|
||||
|
||||
List<Long> possibleModels = getPossibleModels(queryCtx, originals, modelStat, queryCtx.getModelId());
|
||||
|
||||
// 4.1 priority dimension metric
|
||||
// 5.1 priority dimension metric
|
||||
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName,
|
||||
searchTextEntry, searchResults);
|
||||
|
||||
// 4.2 process based on dimension values
|
||||
// 5.2 process based on dimension values
|
||||
MatchText matchText = searchTextEntry.getKey();
|
||||
Map<String, String> natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels));
|
||||
log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap);
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
package com.tencent.supersonic.chat.utils;
|
||||
|
||||
|
||||
import com.plexpt.chatgpt.ChatGPT;
|
||||
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
|
||||
import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse;
|
||||
import com.plexpt.chatgpt.entity.chat.Message;
|
||||
import com.plexpt.chatgpt.util.Proxys;
|
||||
import java.net.Proxy;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.Arrays;
|
||||
import java.util.Date;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
|
||||
@Component
|
||||
public class ChatGptHelper {
|
||||
|
||||
@Value("${llm.chatgpt.apikey:xx-xxxx}")
|
||||
private String apiKey;
|
||||
|
||||
@Value("${llm.chatgpt.apiHost:https://api.openai.com/}")
|
||||
private String apiHost;
|
||||
|
||||
@Value("${llm.chatgpt.proxyIp:default}")
|
||||
private String proxyIp;
|
||||
|
||||
@Value("${llm.chatgpt.proxyPort:8080}")
|
||||
private Integer proxyPort;
|
||||
|
||||
|
||||
public ChatGPT getChatGPT() {
|
||||
Proxy proxy = null;
|
||||
if (!"default".equals(proxyIp)) {
|
||||
proxy = Proxys.http(proxyIp, proxyPort);
|
||||
}
|
||||
return ChatGPT.builder()
|
||||
.apiKey(apiKey)
|
||||
.proxy(proxy)
|
||||
.timeout(900)
|
||||
.apiHost(apiHost) //反向代理地址
|
||||
.build()
|
||||
.init();
|
||||
}
|
||||
|
||||
public String inferredTime(String queryText) {
|
||||
long nowTime = System.currentTimeMillis();
|
||||
Date date = new Date(nowTime);
|
||||
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");
|
||||
String formattedDate = sdf.format(date);
|
||||
Message system = Message.ofSystem("现在时间 " + formattedDate + ",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。"
|
||||
+ "你需要遵守以下规则:\n"
|
||||
+ "1.返回规范的数据格式,json,如: 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}"
|
||||
+ "2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n"
|
||||
+ "3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n"
|
||||
+ "4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n"
|
||||
+ "5.时间给出要是绝对正确,不能瞎编\n"
|
||||
);
|
||||
Message message = Message.of("输入:" + queryText + ",输出:");
|
||||
ChatCompletion chatCompletion = ChatCompletion.builder()
|
||||
.model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName())
|
||||
.messages(Arrays.asList(system, message))
|
||||
.maxTokens(10000)
|
||||
.temperature(0.9)
|
||||
.build();
|
||||
ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion);
|
||||
Message res = response.getChoices().get(0).getMessage();
|
||||
return res.getContent();
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.utils;
|
||||
import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticLayer;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.parser.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.DSLOptimizer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.function.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
@@ -15,10 +18,11 @@ public class ComponentFactory {
|
||||
|
||||
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
||||
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
||||
|
||||
private static List<DSLOptimizer> dslCorrections = new ArrayList<>();
|
||||
private static SemanticLayer semanticLayer;
|
||||
private static QuerySelector querySelector;
|
||||
private static ModelResolver modelResolver;
|
||||
|
||||
public static List<SchemaMapper> getSchemaMappers() {
|
||||
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
|
||||
}
|
||||
@@ -27,6 +31,11 @@ public class ComponentFactory {
|
||||
return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers;
|
||||
}
|
||||
|
||||
public static List<DSLOptimizer> getSqlCorrections() {
|
||||
return CollectionUtils.isEmpty(dslCorrections) ? init(DSLOptimizer.class, dslCorrections) : dslCorrections;
|
||||
}
|
||||
|
||||
|
||||
public static SemanticLayer getSemanticLayer() {
|
||||
if (Objects.isNull(semanticLayer)) {
|
||||
semanticLayer = init(SemanticLayer.class);
|
||||
|
||||
@@ -74,7 +74,7 @@ public class NatureHelper {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static boolean isDimensionValueClassId(String nature) {
|
||||
public static boolean isDimensionValueModelId(String nature) {
|
||||
if (StringUtils.isEmpty(nature)) {
|
||||
return false;
|
||||
}
|
||||
@@ -104,7 +104,7 @@ public class NatureHelper {
|
||||
}
|
||||
|
||||
private static long getDimensionValueCount(List<Term> terms) {
|
||||
return terms.stream().filter(term -> isDimensionValueClassId(term.nature.toString())).count();
|
||||
return terms.stream().filter(term -> isDimensionValueModelId(term.nature.toString())).count();
|
||||
}
|
||||
|
||||
private static long getDimensionCount(List<Term> terms) {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env bash
|
||||
# python path
|
||||
export python_path="/usr/local/bin/python3.9"
|
||||
# pip path
|
||||
|
||||
303
chat/core/src/main/resources/mapper/AgentDOMapper.xml
Normal file
303
chat/core/src/main/resources/mapper/AgentDOMapper.xml
Normal file
@@ -0,0 +1,303 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
<mapper namespace="com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper">
|
||||
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
<id column="id" jdbcType="INTEGER" property="id" />
|
||||
<result column="name" jdbcType="VARCHAR" property="name" />
|
||||
<result column="description" jdbcType="VARCHAR" property="description" />
|
||||
<result column="status" jdbcType="INTEGER" property="status" />
|
||||
<result column="examples" jdbcType="VARCHAR" property="examples" />
|
||||
<result column="config" jdbcType="VARCHAR" property="config" />
|
||||
<result column="created_by" jdbcType="VARCHAR" property="createdBy" />
|
||||
<result column="created_at" jdbcType="TIMESTAMP" property="createdAt" />
|
||||
<result column="updated_by" jdbcType="VARCHAR" property="updatedBy" />
|
||||
<result column="updated_at" jdbcType="TIMESTAMP" property="updatedAt" />
|
||||
<result column="enable_search" jdbcType="INTEGER" property="enableSearch" />
|
||||
</resultMap>
|
||||
<sql id="Example_Where_Clause">
|
||||
<where>
|
||||
<foreach collection="oredCriteria" item="criteria" separator="or">
|
||||
<if test="criteria.valid">
|
||||
<trim prefix="(" prefixOverrides="and" suffix=")">
|
||||
<foreach collection="criteria.criteria" item="criterion">
|
||||
<choose>
|
||||
<when test="criterion.noValue">
|
||||
and ${criterion.condition}
|
||||
</when>
|
||||
<when test="criterion.singleValue">
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
</choose>
|
||||
</foreach>
|
||||
</trim>
|
||||
</if>
|
||||
</foreach>
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Update_By_Example_Where_Clause">
|
||||
<where>
|
||||
<foreach collection="example.oredCriteria" item="criteria" separator="or">
|
||||
<if test="criteria.valid">
|
||||
<trim prefix="(" prefixOverrides="and" suffix=")">
|
||||
<foreach collection="criteria.criteria" item="criterion">
|
||||
<choose>
|
||||
<when test="criterion.noValue">
|
||||
and ${criterion.condition}
|
||||
</when>
|
||||
<when test="criterion.singleValue">
|
||||
and ${criterion.condition} #{criterion.value}
|
||||
</when>
|
||||
<when test="criterion.betweenValue">
|
||||
and ${criterion.condition} #{criterion.value} and #{criterion.secondValue}
|
||||
</when>
|
||||
<when test="criterion.listValue">
|
||||
and ${criterion.condition}
|
||||
<foreach close=")" collection="criterion.value" item="listItem" open="(" separator=",">
|
||||
#{listItem}
|
||||
</foreach>
|
||||
</when>
|
||||
</choose>
|
||||
</foreach>
|
||||
</trim>
|
||||
</if>
|
||||
</foreach>
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Base_Column_List">
|
||||
id, name, description, status, examples, config, created_by, created_at, updated_by,
|
||||
updated_at, enable_search
|
||||
</sql>
|
||||
<select id="selectByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample" resultMap="BaseResultMap">
|
||||
select
|
||||
<if test="distinct">
|
||||
distinct
|
||||
</if>
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_agent
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
<if test="orderByClause != null">
|
||||
order by ${orderByClause}
|
||||
</if>
|
||||
<if test="limitStart != null and limitStart>=0">
|
||||
limit #{limitStart} , #{limitEnd}
|
||||
</if>
|
||||
</select>
|
||||
<select id="selectByPrimaryKey" parameterType="java.lang.Integer" resultMap="BaseResultMap">
|
||||
select
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_agent
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</select>
|
||||
<delete id="deleteByPrimaryKey" parameterType="java.lang.Integer">
|
||||
delete from s2_agent
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</delete>
|
||||
<insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
insert into s2_agent (id, name, description,
|
||||
status, examples, config,
|
||||
created_by, created_at, updated_by,
|
||||
updated_at, enable_search)
|
||||
values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR},
|
||||
#{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR},
|
||||
#{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR},
|
||||
#{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER})
|
||||
</insert>
|
||||
<insert id="insertSelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
insert into s2_agent
|
||||
<trim prefix="(" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
id,
|
||||
</if>
|
||||
<if test="name != null">
|
||||
name,
|
||||
</if>
|
||||
<if test="description != null">
|
||||
description,
|
||||
</if>
|
||||
<if test="status != null">
|
||||
status,
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
examples,
|
||||
</if>
|
||||
<if test="config != null">
|
||||
config,
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
created_by,
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
created_at,
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
updated_by,
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
updated_at,
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
enable_search,
|
||||
</if>
|
||||
</trim>
|
||||
<trim prefix="values (" suffix=")" suffixOverrides=",">
|
||||
<if test="id != null">
|
||||
#{id,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="name != null">
|
||||
#{name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="description != null">
|
||||
#{description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="status != null">
|
||||
#{status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
#{examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="config != null">
|
||||
#{config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
#{createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
#{createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
#{updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
#{updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
#{enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</trim>
|
||||
</insert>
|
||||
<select id="countByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample" resultType="java.lang.Long">
|
||||
select count(*) from s2_agent
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
</select>
|
||||
<update id="updateByExampleSelective" parameterType="map">
|
||||
update s2_agent
|
||||
<set>
|
||||
<if test="record.id != null">
|
||||
id = #{record.id,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="record.name != null">
|
||||
name = #{record.name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.description != null">
|
||||
description = #{record.description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.status != null">
|
||||
status = #{record.status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="record.examples != null">
|
||||
examples = #{record.examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.config != null">
|
||||
config = #{record.config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.createdBy != null">
|
||||
created_by = #{record.createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.createdAt != null">
|
||||
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="record.updatedBy != null">
|
||||
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="record.updatedAt != null">
|
||||
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="record.enableSearch != null">
|
||||
enable_search = #{record.enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</set>
|
||||
<if test="_parameter != null">
|
||||
<include refid="Update_By_Example_Where_Clause" />
|
||||
</if>
|
||||
</update>
|
||||
<update id="updateByExample" parameterType="map">
|
||||
update s2_agent
|
||||
set id = #{record.id,jdbcType=INTEGER},
|
||||
name = #{record.name,jdbcType=VARCHAR},
|
||||
description = #{record.description,jdbcType=VARCHAR},
|
||||
status = #{record.status,jdbcType=INTEGER},
|
||||
examples = #{record.examples,jdbcType=VARCHAR},
|
||||
config = #{record.config,jdbcType=VARCHAR},
|
||||
created_by = #{record.createdBy,jdbcType=VARCHAR},
|
||||
created_at = #{record.createdAt,jdbcType=TIMESTAMP},
|
||||
updated_by = #{record.updatedBy,jdbcType=VARCHAR},
|
||||
updated_at = #{record.updatedAt,jdbcType=TIMESTAMP},
|
||||
enable_search = #{record.enableSearch,jdbcType=INTEGER}
|
||||
<if test="_parameter != null">
|
||||
<include refid="Update_By_Example_Where_Clause" />
|
||||
</if>
|
||||
</update>
|
||||
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
update s2_agent
|
||||
<set>
|
||||
<if test="name != null">
|
||||
name = #{name,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="description != null">
|
||||
description = #{description,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="status != null">
|
||||
status = #{status,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="examples != null">
|
||||
examples = #{examples,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="config != null">
|
||||
config = #{config,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdBy != null">
|
||||
created_by = #{createdBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="createdAt != null">
|
||||
created_at = #{createdAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="updatedBy != null">
|
||||
updated_by = #{updatedBy,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="updatedAt != null">
|
||||
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="enableSearch != null">
|
||||
enable_search = #{enableSearch,jdbcType=INTEGER},
|
||||
</if>
|
||||
</set>
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</update>
|
||||
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.chat.persistence.dataobject.AgentDO">
|
||||
update s2_agent
|
||||
set name = #{name,jdbcType=VARCHAR},
|
||||
description = #{description,jdbcType=VARCHAR},
|
||||
status = #{status,jdbcType=INTEGER},
|
||||
examples = #{examples,jdbcType=VARCHAR},
|
||||
config = #{config,jdbcType=VARCHAR},
|
||||
created_by = #{createdBy,jdbcType=VARCHAR},
|
||||
created_at = #{createdAt,jdbcType=TIMESTAMP},
|
||||
updated_by = #{updatedBy,jdbcType=VARCHAR},
|
||||
updated_at = #{updatedAt,jdbcType=TIMESTAMP},
|
||||
enable_search = #{enableSearch,jdbcType=INTEGER}
|
||||
where id = #{id,jdbcType=INTEGER}
|
||||
</update>
|
||||
</mapper>
|
||||
@@ -59,4 +59,18 @@ CREATE TABLE `chat_query`
|
||||
KEY `common` (`question_id`),
|
||||
KEY `common1` (`user_name`),
|
||||
KEY `common2` (`chat_id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
|
||||
|
||||
|
||||
CREATE TABLE `chat`
|
||||
(
|
||||
`chat_id` bigint(8) NOT NULL AUTO_INCREMENT,
|
||||
`chat_name` varchar(100) DEFAULT NULL,
|
||||
`create_time` datetime DEFAULT NULL,
|
||||
`last_time` datetime DEFAULT NULL,
|
||||
`creator` varchar(30) DEFAULT NULL,
|
||||
`last_question` varchar(200) DEFAULT NULL,
|
||||
`is_delete` int(2) DEFAULT '0' COMMENT 'is deleted',
|
||||
`is_top` int(2) DEFAULT '0' COMMENT 'is top',
|
||||
PRIMARY KEY (`chat_id`)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8;
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class DateFieldCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
SchemaElement model = new SchemaElement();
|
||||
model.setId(2L);
|
||||
parseInfo.setModel(model);
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 ")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
CorrectionInfo rewriter = dateFieldCorrector.rewriter(correctionInfo);
|
||||
|
||||
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", rewriter.getSql());
|
||||
|
||||
correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
|
||||
.parseInfo(parseInfo)
|
||||
.build();
|
||||
|
||||
rewriter = dateFieldCorrector.rewriter(correctionInfo);
|
||||
|
||||
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", rewriter.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.tencent.supersonic.chat.query.dsl.optimizer;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.CorrectionInfo;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class SelectFieldAppendCorrectorTest {
|
||||
|
||||
@Test
|
||||
void rewriter() {
|
||||
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
|
||||
CorrectionInfo correctionInfo = CorrectionInfo.builder()
|
||||
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
|
||||
.build();
|
||||
|
||||
CorrectionInfo rewriter = corrector.rewriter(correctionInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 歌曲名, 歌手名, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11",
|
||||
rewriter.getSql());
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,17 @@
|
||||
package com.tencent.supersonic.knowledge.dictionary.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* dimension word nature
|
||||
@@ -23,6 +27,7 @@ public class DimensionWordBuilder extends BaseWordBuilder {
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOnwWordNature(word, schemaElement, false));
|
||||
result.addAll(getOnwWordNatureAlias(schemaElement, false));
|
||||
if (nlpDimensionUseSuffix) {
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) {
|
||||
@@ -46,4 +51,16 @@ public class DimensionWordBuilder extends BaseWordBuilder {
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOnwWordNature(alias, schemaElement, false));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
package com.tencent.supersonic.knowledge.dictionary.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
@@ -22,6 +26,7 @@ public class MetricWordBuilder extends BaseWordBuilder {
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOnwWordNature(word, schemaElement, false));
|
||||
result.addAll(getOnwWordNatureAlias(schemaElement, false));
|
||||
if (nlpMetricUseSuffix) {
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (!word.equalsIgnoreCase(reverseWord)) {
|
||||
@@ -45,4 +50,16 @@ public class MetricWordBuilder extends BaseWordBuilder {
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
private List<DictWord> getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOnwWordNature(alias, schemaElement, false));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,51 +2,38 @@ package com.tencent.supersonic.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.*;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.DomainService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.ModelService;
|
||||
import com.tencent.supersonic.semantic.query.service.QueryService;
|
||||
import com.tencent.supersonic.semantic.query.service.SchemaService;
|
||||
import java.util.List;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class LocalSemanticLayer extends BaseSemanticLayer {
|
||||
|
||||
private SchemaService schemaService;
|
||||
private S2ThreadContext s2ThreadContext;
|
||||
private DomainService domainService;
|
||||
private ModelService modelService;
|
||||
private DimensionService dimensionService;
|
||||
private MetricService metricService;
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) {
|
||||
try {
|
||||
QueryService queryService = ContextUtils.getBean(QueryService.class);
|
||||
QueryResultWithSchemaResp queryResultWithSchemaResp = queryService.queryByStruct(queryStructReq, user);
|
||||
return queryResultWithSchemaResp;
|
||||
} catch (Exception e) {
|
||||
log.info("queryByStruct has an exception:{}", e.toString());
|
||||
}
|
||||
return null;
|
||||
public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user){
|
||||
QueryService queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.queryByStructWithAuth(queryStructReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -8,21 +8,17 @@ import com.tencent.supersonic.semantic.api.model.pojo.Entity;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ModelSchemaBuilder {
|
||||
|
||||
private static String aliasSplit = ",";
|
||||
|
||||
public static ModelSchema build(ModelSchemaResp resp) {
|
||||
ModelSchema domainSchema = new ModelSchema();
|
||||
|
||||
@@ -37,6 +33,13 @@ public class ModelSchemaBuilder {
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
for (MetricSchemaResp metric : resp.getMetrics()) {
|
||||
|
||||
List<String> alias = new ArrayList<>();
|
||||
String aliasStr = metric.getAlias();
|
||||
if (Strings.isNotEmpty(aliasStr)) {
|
||||
alias = Arrays.asList(aliasStr.split(aliasSplit));
|
||||
}
|
||||
|
||||
SchemaElement metricToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(metric.getId())
|
||||
@@ -44,16 +47,10 @@ public class ModelSchemaBuilder {
|
||||
.bizName(metric.getBizName())
|
||||
.type(SchemaElementType.METRIC)
|
||||
.useCnt(metric.getUseCnt())
|
||||
.alias(alias)
|
||||
.build();
|
||||
metrics.add(metricToAdd);
|
||||
|
||||
String alias = metric.getAlias();
|
||||
if (StringUtils.isNotEmpty(alias)) {
|
||||
SchemaElement alisMetricToAdd = new SchemaElement();
|
||||
BeanUtils.copyProperties(metricToAdd, alisMetricToAdd);
|
||||
alisMetricToAdd.setName(alias);
|
||||
metrics.add(alisMetricToAdd);
|
||||
}
|
||||
}
|
||||
domainSchema.getMetrics().addAll(metrics);
|
||||
|
||||
@@ -74,6 +71,11 @@ public class ModelSchemaBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
List<String> alias = new ArrayList<>();
|
||||
String aliasStr = dim.getAlias();
|
||||
if (Strings.isNotEmpty(aliasStr)) {
|
||||
alias = Arrays.asList(aliasStr.split(aliasSplit));
|
||||
}
|
||||
SchemaElement dimToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
@@ -81,17 +83,10 @@ public class ModelSchemaBuilder {
|
||||
.bizName(dim.getBizName())
|
||||
.type(SchemaElementType.DIMENSION)
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(alias)
|
||||
.build();
|
||||
dimensions.add(dimToAdd);
|
||||
|
||||
String alias = dim.getAlias();
|
||||
if (StringUtils.isNotEmpty(alias)) {
|
||||
SchemaElement alisDimToAdd = new SchemaElement();
|
||||
BeanUtils.copyProperties(dimToAdd, alisDimToAdd);
|
||||
alisDimToAdd.setName(alias);
|
||||
dimensions.add(alisDimToAdd);
|
||||
}
|
||||
|
||||
SchemaElement dimValueToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.id(dim.getId())
|
||||
@@ -115,7 +110,7 @@ public class ModelSchemaBuilder {
|
||||
.collect(
|
||||
Collectors.toMap(SchemaElement::getId, schemaElement -> schemaElement, (k1, k2) -> k2));
|
||||
if (idAndDimPair.containsKey(entity.getEntityId())) {
|
||||
entityElement = idAndDimPair.get(entity.getEntityId());
|
||||
BeanUtils.copyProperties(idAndDimPair.get(entity.getEntityId()), entityElement);
|
||||
entityElement.setType(SchemaElementType.ENTITY);
|
||||
}
|
||||
entityElement.setAlias(entity.getNames());
|
||||
|
||||
Reference in New Issue
Block a user