diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index ae5107a4f..60c506aba 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -4,6 +4,9 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import org.apache.commons.collections.MapUtils; + +import java.util.Objects; public class QueryReqConverter { @@ -18,6 +21,10 @@ public class QueryReqConverter { queryReq.setEnableLLM(true); } queryReq.setDataSetIds(agent.getDataSetIds()); + if (Objects.nonNull(queryReq.getMapInfo()) + && MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) { + queryReq.setMapInfo(queryReq.getMapInfo()); + } return queryReq; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index fa4831305..15c39df12 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import lombok.Data; import java.util.Set; @@ -15,4 +16,5 @@ public class QueryReq { private QueryFilters queryFilters; private boolean saveAnswer = true; private boolean enableLLM; + private SchemaMapInfo mapInfo = new SchemaMapInfo(); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java index f1c8f3e83..2a78fbc4a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java @@ -40,7 +40,7 @@ public class OptimizationConfig { @Value("${embedding.mapper.word.min:4}") private int embeddingMapperWordMin; - @Value("${embedding.mapper.word.max:5}") + @Value("${embedding.mapper.word.max:4}") private int embeddingMapperWordMax; @Value("${embedding.mapper.batch:50}") diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 8d6eadbcf..a461bfcf6 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -72,6 +72,7 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.schema.Column; import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.collections.MapUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.BeanUtils; @@ -129,9 +130,12 @@ public class ChatQueryServiceImpl implements ChatQueryService { ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId()); // 1. mapper - schemaMappers.forEach(mapper -> { - mapper.map(queryCtx); - }); + if (Objects.isNull(queryReq.getMapInfo()) + || MapUtils.isEmpty(queryReq.getMapInfo().getDataSetElementMatches())) { + schemaMappers.forEach(mapper -> { + mapper.map(queryCtx); + }); + } // 2. parser semanticParsers.forEach(parser -> { @@ -209,7 +213,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private QueryResult doExecution(SemanticQueryReq semanticQueryReq, - SemanticParseInfo parseInfo, User user) throws Exception { + SemanticParseInfo parseInfo, User user) throws Exception { SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user); QueryResult queryResult = new QueryResult(); if (queryResp != null) { @@ -357,10 +361,10 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, - Map> filedNameToValueMap, - List fieldExpressionList, - List addConditions, - Set removeFieldNames) { + Map> filedNameToValueMap, + List fieldExpressionList, + List addConditions, + Set removeFieldNames) { if (Objects.isNull(queryData.getDateInfo())) { return; } @@ -424,8 +428,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void addTimeFilters(String date, - T comparisonExpression, - List addConditions) { + T comparisonExpression, + List addConditions) { Column column = new Column(TimeDimensionEnum.DAY.getChName()); StringValue stringValue = new StringValue(date); comparisonExpression.setLeftExpression(column); @@ -434,10 +438,10 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void updateFilters(List fieldExpressionList, - Set metricFilters, - Set contextMetricFilters, - List addConditions, - Set removeFieldNames) { + Set metricFilters, + Set contextMetricFilters, + List addConditions, + Set removeFieldNames) { if (CollectionUtils.isEmpty(metricFilters)) { return; } @@ -473,9 +477,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { // add in condition to sql where condition private void addWhereInFilters(QueryFilter dslQueryFilter, - InExpression inExpression, - Set contextMetricFilters, - List addConditions) { + InExpression inExpression, + Set contextMetricFilters, + List addConditions) { Column column = new Column(dslQueryFilter.getName()); ExpressionList expressionList = new ExpressionList(); List expressions = new ArrayList<>(); @@ -502,9 +506,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { // add where filter private void addWhereFilters(QueryFilter dslQueryFilter, - T comparisonExpression, - Set contextMetricFilters, - List addConditions) { + T comparisonExpression, + Set contextMetricFilters, + List addConditions) { String columnName = dslQueryFilter.getName(); if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";