(improvement)(Chat) fix QueryFilter putting element (#848)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-21 14:11:00 +08:00
committed by GitHub
parent 031b2bff5f
commit dfba275811
5 changed files with 33 additions and 14 deletions

View File

@@ -3,9 +3,11 @@ package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import java.util.List;
public class NL2SQLParser implements ChatParser {
@@ -14,6 +16,9 @@ public class NL2SQLParser implements ChatParser {
if (!chatParseContext.enableNL2SQL()) {
return;
}
if (checkSkip(parseResp)) {
return;
}
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
@@ -22,4 +27,14 @@ public class NL2SQLParser implements ChatParser {
}
}
private boolean checkSkip(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo semanticParseInfo : selectedParses) {
if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) {
return true;
}
}
return false;
}
}

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -71,6 +72,9 @@ public abstract class PluginRecognizer {
}
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSet(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);

View File

@@ -49,11 +49,11 @@ public class QueryFilterMapper extends BaseMapper {
}
}
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = queryContext.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return candidateElementMatches;
return;
}
for (QueryFilter filter : queryFilters.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
@@ -64,7 +64,7 @@ public class QueryFilterMapper extends BaseMapper {
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.dataSet(viewId)
.dataSet(dataSetId)
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
@@ -75,7 +75,7 @@ public class QueryFilterMapper extends BaseMapper {
.build();
candidateElementMatches.add(schemaElementMatch);
}
return candidateElementMatches;
queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson2.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.common.cache.Cache;
@@ -35,6 +35,13 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.TagMetaService;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Date;
@@ -45,13 +52,6 @@ import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service
public class DataSetServiceImpl

View File

@@ -59,8 +59,8 @@ s2:
langchain4j:
#1.chat-model
chat-model:
provider: local_ai
local_ai:
provider: open_ai
openai:
api-key: api_key
model-name: gpt-3.5-turbo-16k
temperature: 0.0