(improvement)(chat) Put queryText to PluginParseResult (#673)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-01-21 16:26:49 +08:00
committed by GitHub
parent 97c767a45b
commit b28eb637c8
3 changed files with 10 additions and 6 deletions

View File

@@ -20,11 +20,11 @@ import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.ModelCluster; import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.util.CollectionUtils;
/** /**
@@ -63,8 +63,7 @@ public abstract class PluginParser implements SemanticParser {
for (Long modelId : modelIds) { for (Long modelId : modelIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType()); PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
queryContext.getQueryFilters(), queryContext.getModelClusterMapInfo().getMatchedElements(modelId), queryContext, pluginRecallResult.getDistance());
pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode()); semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore()); semanticParseInfo.setScore(pluginRecallResult.getScore());
pluginQuery.setParseInfo(semanticParseInfo); pluginQuery.setParseInfo(semanticParseInfo);
@@ -76,8 +75,11 @@ public abstract class PluginParser implements SemanticParser {
return PluginManager.getPluginAgentCanSupport(queryContext); return PluginManager.getPluginAgentCanSupport(queryContext);
} }
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryFilters queryFilters, protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin,
List<SchemaElementMatch> schemaElementMatches, double distance) { QueryContext queryContext, double distance) {
List<SchemaElementMatch> schemaElementMatches =
queryContext.getModelClusterMapInfo().getMatchedElements(modelId);
QueryFilters queryFilters = queryContext.getQueryFilters();
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) { if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0); modelId = plugin.getModelList().get(0);
} }
@@ -92,6 +94,7 @@ public abstract class PluginParser implements SemanticParser {
pluginParseResult.setPlugin(plugin); pluginParseResult.setPlugin(plugin);
pluginParseResult.setQueryFilters(queryFilters); pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance); pluginParseResult.setDistance(distance);
pluginParseResult.setQueryText(queryContext.getQueryText());
properties.put(Constants.CONTEXT, pluginParseResult); properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin"); properties.put("type", "plugin");
properties.put("name", plugin.getName()); properties.put("name", plugin.getName());

View File

@@ -9,4 +9,5 @@ public class PluginParseResult {
private Plugin plugin; private Plugin plugin;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private double distance; private double distance;
private String queryText;
} }

View File

@@ -334,7 +334,7 @@ public class DimensionServiceImpl implements DimensionService {
.collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1)); .collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1));
for (DimensionReq dimensionReq : dimensionReqs) { for (DimensionReq dimensionReq : dimensionReqs) {
if (NameCheckUtils.containsSpecialCharacters(dimensionReq.getName())) { if (NameCheckUtils.containsSpecialCharacters(dimensionReq.getName())) {
throw new InvalidArgumentException("名称包含特殊字符, 请修改"); throw new InvalidArgumentException("名称包含特殊字符, 请修改: " + dimensionReq.getName());
} }
if (bizNameMap.containsKey(dimensionReq.getBizName())) { if (bizNameMap.containsKey(dimensionReq.getBizName())) {
DimensionResp dimensionResp = bizNameMap.get(dimensionReq.getBizName()); DimensionResp dimensionResp = bizNameMap.get(dimensionReq.getBizName());