(improvement)(headless) Reduce the two calls to the mapper in the parser stage. (#888)

This commit is contained in:
lexluo09
2024-04-06 11:49:44 +08:00
committed by GitHub
parent 0577090b39
commit faeb5bbeac
4 changed files with 34 additions and 21 deletions

View File

@@ -4,6 +4,9 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import org.apache.commons.collections.MapUtils;
import java.util.Objects;
public class QueryReqConverter { public class QueryReqConverter {
@@ -18,6 +21,10 @@ public class QueryReqConverter {
queryReq.setEnableLLM(true); queryReq.setEnableLLM(true);
} }
queryReq.setDataSetIds(agent.getDataSetIds()); queryReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryReq.getMapInfo())
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
queryReq.setMapInfo(queryReq.getMapInfo());
}
return queryReq; return queryReq;
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import lombok.Data; import lombok.Data;
import java.util.Set; import java.util.Set;
@@ -15,4 +16,5 @@ public class QueryReq {
private QueryFilters queryFilters; private QueryFilters queryFilters;
private boolean saveAnswer = true; private boolean saveAnswer = true;
private boolean enableLLM; private boolean enableLLM;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
} }

View File

@@ -40,7 +40,7 @@ public class OptimizationConfig {
@Value("${embedding.mapper.word.min:4}") @Value("${embedding.mapper.word.min:4}")
private int embeddingMapperWordMin; private int embeddingMapperWordMin;
@Value("${embedding.mapper.word.max:5}") @Value("${embedding.mapper.word.max:4}")
private int embeddingMapperWordMax; private int embeddingMapperWordMax;
@Value("${embedding.mapper.batch:50}") @Value("${embedding.mapper.batch:50}")

View File

@@ -72,6 +72,7 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
@@ -129,9 +130,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId()); ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
// 1. mapper // 1. mapper
schemaMappers.forEach(mapper -> { if (Objects.isNull(queryReq.getMapInfo())
mapper.map(queryCtx); || MapUtils.isEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
}); schemaMappers.forEach(mapper -> {
mapper.map(queryCtx);
});
}
// 2. parser // 2. parser
semanticParsers.forEach(parser -> { semanticParsers.forEach(parser -> {
@@ -209,7 +213,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private QueryResult doExecution(SemanticQueryReq semanticQueryReq, private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
SemanticParseInfo parseInfo, User user) throws Exception { SemanticParseInfo parseInfo, User user) throws Exception {
SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user); SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
if (queryResp != null) { if (queryResp != null) {
@@ -357,10 +361,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
Map<String, Map<String, String>> filedNameToValueMap, Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList, List<FieldExpression> fieldExpressionList,
List<Expression> addConditions, List<Expression> addConditions,
Set<String> removeFieldNames) { Set<String> removeFieldNames) {
if (Objects.isNull(queryData.getDateInfo())) { if (Objects.isNull(queryData.getDateInfo())) {
return; return;
} }
@@ -424,8 +428,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private <T extends ComparisonOperator> void addTimeFilters(String date, private <T extends ComparisonOperator> void addTimeFilters(String date,
T comparisonExpression, T comparisonExpression,
List<Expression> addConditions) { List<Expression> addConditions) {
Column column = new Column(TimeDimensionEnum.DAY.getChName()); Column column = new Column(TimeDimensionEnum.DAY.getChName());
StringValue stringValue = new StringValue(date); StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column); comparisonExpression.setLeftExpression(column);
@@ -434,10 +438,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void updateFilters(List<FieldExpression> fieldExpressionList, private void updateFilters(List<FieldExpression> fieldExpressionList,
Set<QueryFilter> metricFilters, Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions, List<Expression> addConditions,
Set<String> removeFieldNames) { Set<String> removeFieldNames) {
if (CollectionUtils.isEmpty(metricFilters)) { if (CollectionUtils.isEmpty(metricFilters)) {
return; return;
} }
@@ -473,9 +477,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// add in condition to sql where condition // add in condition to sql where condition
private void addWhereInFilters(QueryFilter dslQueryFilter, private void addWhereInFilters(QueryFilter dslQueryFilter,
InExpression inExpression, InExpression inExpression,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName()); Column column = new Column(dslQueryFilter.getName());
ExpressionList expressionList = new ExpressionList(); ExpressionList expressionList = new ExpressionList();
List<Expression> expressions = new ArrayList<>(); List<Expression> expressions = new ArrayList<>();
@@ -502,9 +506,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// add where filter // add where filter
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter, private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
T comparisonExpression, T comparisonExpression,
Set<QueryFilter> contextMetricFilters, Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
String columnName = dslQueryFilter.getName(); String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")"; columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";