(improvement)(Chat) fix QueryFilter putting element (#848)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-21 14:11:00 +08:00
committed by GitHub
parent 031b2bff5f
commit dfba275811
5 changed files with 33 additions and 14 deletions

View File

@@ -3,9 +3,11 @@ package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.service.ChatQueryService; import com.tencent.supersonic.headless.server.service.ChatQueryService;
import java.util.List;
public class NL2SQLParser implements ChatParser { public class NL2SQLParser implements ChatParser {
@@ -14,6 +16,9 @@ public class NL2SQLParser implements ChatParser {
if (!chatParseContext.enableNL2SQL()) { if (!chatParseContext.enableNL2SQL()) {
return; return;
} }
if (checkSkip(parseResp)) {
return;
}
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
@@ -22,4 +27,14 @@ public class NL2SQLParser implements ChatParser {
} }
} }
private boolean checkSkip(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo semanticParseInfo : selectedParses) {
if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) {
return true;
}
}
return false;
}
} }

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -71,6 +72,9 @@ public abstract class PluginRecognizer {
} }
SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches); semanticParseInfo.setElementMatches(schemaElementMatches);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
Map<String, Object> properties = new HashMap<>(); Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult(); PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin); pluginParseResult.setPlugin(plugin);

View File

@@ -49,11 +49,11 @@ public class QueryFilterMapper extends BaseMapper {
} }
} }
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext, private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
List<SchemaElementMatch> candidateElementMatches) { List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = queryContext.getQueryFilters(); QueryFilters queryFilters = queryContext.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return candidateElementMatches; return;
} }
for (QueryFilter filter : queryFilters.getFilters()) { for (QueryFilter filter : queryFilters.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) { if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
@@ -64,7 +64,7 @@ public class QueryFilterMapper extends BaseMapper {
.name(String.valueOf(filter.getValue())) .name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE) .type(SchemaElementType.VALUE)
.bizName(filter.getBizName()) .bizName(filter.getBizName())
.dataSet(viewId) .dataSet(dataSetId)
.build(); .build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element) .element(element)
@@ -75,7 +75,7 @@ public class QueryFilterMapper extends BaseMapper {
.build(); .build();
candidateElementMatches.add(schemaElementMatch); candidateElementMatches.add(schemaElementMatch);
} }
return candidateElementMatches; queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
} }
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter, private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.server.service.impl; package com.tencent.supersonic.headless.server.service.impl;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson2.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.common.cache.Cache; import com.google.common.cache.Cache;
@@ -35,6 +35,13 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.TagMetaService; import com.tencent.supersonic.headless.server.service.TagMetaService;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
import java.util.Date; import java.util.Date;
@@ -45,13 +52,6 @@ import java.util.Set;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.function.Function; import java.util.function.Function;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service @Service
public class DataSetServiceImpl public class DataSetServiceImpl

View File

@@ -59,8 +59,8 @@ s2:
langchain4j: langchain4j:
#1.chat-model #1.chat-model
chat-model: chat-model:
provider: local_ai provider: open_ai
local_ai: openai:
api-key: api_key api-key: api_key
model-name: gpt-3.5-turbo-16k model-name: gpt-3.5-turbo-16k
temperature: 0.0 temperature: 0.0