mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(Chat) fix QueryFilter putting element (#848)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -3,9 +3,11 @@ package com.tencent.supersonic.chat.server.parser;
|
|||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
public class NL2SQLParser implements ChatParser {
|
public class NL2SQLParser implements ChatParser {
|
||||||
|
|
||||||
@@ -14,6 +16,9 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
if (!chatParseContext.enableNL2SQL()) {
|
if (!chatParseContext.enableNL2SQL()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (checkSkip(parseResp)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||||
@@ -22,4 +27,14 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean checkSkip(ParseResp parseResp) {
|
||||||
|
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||||
|
for (SemanticParseInfo semanticParseInfo : selectedParses) {
|
||||||
|
if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
|||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
@@ -71,6 +72,9 @@ public abstract class PluginRecognizer {
|
|||||||
}
|
}
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||||
|
SchemaElement schemaElement = new SchemaElement();
|
||||||
|
schemaElement.setDataSet(dataSetId);
|
||||||
|
semanticParseInfo.setDataSet(schemaElement);
|
||||||
Map<String, Object> properties = new HashMap<>();
|
Map<String, Object> properties = new HashMap<>();
|
||||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||||
pluginParseResult.setPlugin(plugin);
|
pluginParseResult.setPlugin(plugin);
|
||||||
|
|||||||
@@ -49,11 +49,11 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElementMatch> addValueSchemaElementMatch(Long viewId, QueryContext queryContext,
|
private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
|
||||||
List<SchemaElementMatch> candidateElementMatches) {
|
List<SchemaElementMatch> candidateElementMatches) {
|
||||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
return candidateElementMatches;
|
return;
|
||||||
}
|
}
|
||||||
for (QueryFilter filter : queryFilters.getFilters()) {
|
for (QueryFilter filter : queryFilters.getFilters()) {
|
||||||
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
||||||
@@ -64,7 +64,7 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
.name(String.valueOf(filter.getValue()))
|
.name(String.valueOf(filter.getValue()))
|
||||||
.type(SchemaElementType.VALUE)
|
.type(SchemaElementType.VALUE)
|
||||||
.bizName(filter.getBizName())
|
.bizName(filter.getBizName())
|
||||||
.dataSet(viewId)
|
.dataSet(dataSetId)
|
||||||
.build();
|
.build();
|
||||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||||
.element(element)
|
.element(element)
|
||||||
@@ -75,7 +75,7 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
.build();
|
.build();
|
||||||
candidateElementMatches.add(schemaElementMatch);
|
candidateElementMatches.add(schemaElementMatch);
|
||||||
}
|
}
|
||||||
return candidateElementMatches;
|
queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.server.service.impl;
|
package com.tencent.supersonic.headless.server.service.impl;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson2.JSONObject;
|
||||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
import com.google.common.cache.Cache;
|
import com.google.common.cache.Cache;
|
||||||
@@ -35,6 +35,13 @@ import com.tencent.supersonic.headless.server.service.DimensionService;
|
|||||||
import com.tencent.supersonic.headless.server.service.DomainService;
|
import com.tencent.supersonic.headless.server.service.DomainService;
|
||||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||||
import com.tencent.supersonic.headless.server.service.TagMetaService;
|
import com.tencent.supersonic.headless.server.service.TagMetaService;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
@@ -45,13 +52,6 @@ import java.util.Set;
|
|||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import java.util.function.Function;
|
import java.util.function.Function;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.beans.BeanUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.context.annotation.Lazy;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class DataSetServiceImpl
|
public class DataSetServiceImpl
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ s2:
|
|||||||
langchain4j:
|
langchain4j:
|
||||||
#1.chat-model
|
#1.chat-model
|
||||||
chat-model:
|
chat-model:
|
||||||
provider: local_ai
|
provider: open_ai
|
||||||
local_ai:
|
openai:
|
||||||
api-key: api_key
|
api-key: api_key
|
||||||
model-name: gpt-3.5-turbo-16k
|
model-name: gpt-3.5-turbo-16k
|
||||||
temperature: 0.0
|
temperature: 0.0
|
||||||
|
|||||||
Reference in New Issue
Block a user