diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 726b9039b..c61d88e9d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -3,9 +3,11 @@ package com.tencent.supersonic.chat.server.parser; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.server.service.ChatQueryService; +import java.util.List; public class NL2SQLParser implements ChatParser { @@ -14,6 +16,9 @@ public class NL2SQLParser implements ChatParser { if (!chatParseContext.enableNL2SQL()) { return; } + if (checkSkip(parseResp)) { + return; + } QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); @@ -22,4 +27,14 @@ public class NL2SQLParser implements ChatParser { } } + private boolean checkSkip(ParseResp parseResp) { + List selectedParses = parseResp.getSelectedParses(); + for (SemanticParseInfo semanticParseInfo : selectedParses) { + if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) { + return true; + } + } + return false; + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index 95f5446e8..e0ffeba1d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; @@ -71,6 +72,9 @@ public abstract class PluginRecognizer { } SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); semanticParseInfo.setElementMatches(schemaElementMatches); + SchemaElement schemaElement = new SchemaElement(); + schemaElement.setDataSet(dataSetId); + semanticParseInfo.setDataSet(schemaElement); Map properties = new HashMap<>(); PluginParseResult pluginParseResult = new PluginParseResult(); pluginParseResult.setPlugin(plugin); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/QueryFilterMapper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/QueryFilterMapper.java index f048edd23..913877070 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/QueryFilterMapper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/QueryFilterMapper.java @@ -49,11 +49,11 @@ public class QueryFilterMapper extends BaseMapper { } } - private List addValueSchemaElementMatch(Long viewId, QueryContext queryContext, + private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext, List candidateElementMatches) { QueryFilters queryFilters = queryContext.getQueryFilters(); if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { - return candidateElementMatches; + return; } for (QueryFilter filter : queryFilters.getFilters()) { if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) { @@ -64,7 +64,7 @@ public class QueryFilterMapper extends BaseMapper { .name(String.valueOf(filter.getValue())) .type(SchemaElementType.VALUE) .bizName(filter.getBizName()) - .dataSet(viewId) + .dataSet(dataSetId) .build(); SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(element) @@ -75,7 +75,7 @@ public class QueryFilterMapper extends BaseMapper { .build(); candidateElementMatches.add(schemaElementMatch); } - return candidateElementMatches; + queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches); } private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter, diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index 57e33205c..bae880d46 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.server.service.impl; -import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson2.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.google.common.cache.Cache; @@ -35,6 +35,13 @@ import com.tencent.supersonic.headless.server.service.DimensionService; import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.TagMetaService; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; import java.util.Arrays; import java.util.Comparator; import java.util.Date; @@ -45,13 +52,6 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.springframework.beans.BeanUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.context.annotation.Lazy; -import org.springframework.stereotype.Service; -import org.springframework.util.CollectionUtils; @Service public class DataSetServiceImpl diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index 5b5337e1b..cbe243794 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -59,8 +59,8 @@ s2: langchain4j: #1.chat-model chat-model: - provider: local_ai - local_ai: + provider: open_ai + openai: api-key: api_key model-name: gpt-3.5-turbo-16k temperature: 0.0