mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(headless) Reduce the two calls to the mapper in the parser stage. (#888)
This commit is contained in:
@@ -4,6 +4,9 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
|||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||||
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
public class QueryReqConverter {
|
public class QueryReqConverter {
|
||||||
|
|
||||||
@@ -18,6 +21,10 @@ public class QueryReqConverter {
|
|||||||
queryReq.setEnableLLM(true);
|
queryReq.setEnableLLM(true);
|
||||||
}
|
}
|
||||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||||
|
if (Objects.nonNull(queryReq.getMapInfo())
|
||||||
|
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
|
||||||
|
queryReq.setMapInfo(queryReq.getMapInfo());
|
||||||
|
}
|
||||||
return queryReq;
|
return queryReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
|||||||
|
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -15,4 +16,5 @@ public class QueryReq {
|
|||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private boolean enableLLM;
|
private boolean enableLLM;
|
||||||
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ public class OptimizationConfig {
|
|||||||
@Value("${embedding.mapper.word.min:4}")
|
@Value("${embedding.mapper.word.min:4}")
|
||||||
private int embeddingMapperWordMin;
|
private int embeddingMapperWordMin;
|
||||||
|
|
||||||
@Value("${embedding.mapper.word.max:5}")
|
@Value("${embedding.mapper.word.max:4}")
|
||||||
private int embeddingMapperWordMax;
|
private int embeddingMapperWordMax;
|
||||||
|
|
||||||
@Value("${embedding.mapper.batch:50}")
|
@Value("${embedding.mapper.batch:50}")
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan;
|
|||||||
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.collections.MapUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
@@ -129,9 +130,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
|
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
|
||||||
|
|
||||||
// 1. mapper
|
// 1. mapper
|
||||||
schemaMappers.forEach(mapper -> {
|
if (Objects.isNull(queryReq.getMapInfo())
|
||||||
mapper.map(queryCtx);
|
|| MapUtils.isEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
|
||||||
});
|
schemaMappers.forEach(mapper -> {
|
||||||
|
mapper.map(queryCtx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// 2. parser
|
// 2. parser
|
||||||
semanticParsers.forEach(parser -> {
|
semanticParsers.forEach(parser -> {
|
||||||
@@ -209,7 +213,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
|
private QueryResult doExecution(SemanticQueryReq semanticQueryReq,
|
||||||
SemanticParseInfo parseInfo, User user) throws Exception {
|
SemanticParseInfo parseInfo, User user) throws Exception {
|
||||||
SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
|
SemanticQueryResp queryResp = queryService.queryByReq(semanticQueryReq, user);
|
||||||
QueryResult queryResult = new QueryResult();
|
QueryResult queryResult = new QueryResult();
|
||||||
if (queryResp != null) {
|
if (queryResp != null) {
|
||||||
@@ -357,10 +361,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||||
Map<String, Map<String, String>> filedNameToValueMap,
|
Map<String, Map<String, String>> filedNameToValueMap,
|
||||||
List<FieldExpression> fieldExpressionList,
|
List<FieldExpression> fieldExpressionList,
|
||||||
List<Expression> addConditions,
|
List<Expression> addConditions,
|
||||||
Set<String> removeFieldNames) {
|
Set<String> removeFieldNames) {
|
||||||
if (Objects.isNull(queryData.getDateInfo())) {
|
if (Objects.isNull(queryData.getDateInfo())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -424,8 +428,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
private <T extends ComparisonOperator> void addTimeFilters(String date,
|
||||||
T comparisonExpression,
|
T comparisonExpression,
|
||||||
List<Expression> addConditions) {
|
List<Expression> addConditions) {
|
||||||
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
Column column = new Column(TimeDimensionEnum.DAY.getChName());
|
||||||
StringValue stringValue = new StringValue(date);
|
StringValue stringValue = new StringValue(date);
|
||||||
comparisonExpression.setLeftExpression(column);
|
comparisonExpression.setLeftExpression(column);
|
||||||
@@ -434,10 +438,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void updateFilters(List<FieldExpression> fieldExpressionList,
|
private void updateFilters(List<FieldExpression> fieldExpressionList,
|
||||||
Set<QueryFilter> metricFilters,
|
Set<QueryFilter> metricFilters,
|
||||||
Set<QueryFilter> contextMetricFilters,
|
Set<QueryFilter> contextMetricFilters,
|
||||||
List<Expression> addConditions,
|
List<Expression> addConditions,
|
||||||
Set<String> removeFieldNames) {
|
Set<String> removeFieldNames) {
|
||||||
if (CollectionUtils.isEmpty(metricFilters)) {
|
if (CollectionUtils.isEmpty(metricFilters)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -473,9 +477,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
|
|
||||||
// add in condition to sql where condition
|
// add in condition to sql where condition
|
||||||
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
private void addWhereInFilters(QueryFilter dslQueryFilter,
|
||||||
InExpression inExpression,
|
InExpression inExpression,
|
||||||
Set<QueryFilter> contextMetricFilters,
|
Set<QueryFilter> contextMetricFilters,
|
||||||
List<Expression> addConditions) {
|
List<Expression> addConditions) {
|
||||||
Column column = new Column(dslQueryFilter.getName());
|
Column column = new Column(dslQueryFilter.getName());
|
||||||
ExpressionList expressionList = new ExpressionList();
|
ExpressionList expressionList = new ExpressionList();
|
||||||
List<Expression> expressions = new ArrayList<>();
|
List<Expression> expressions = new ArrayList<>();
|
||||||
@@ -502,9 +506,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
|
|
||||||
// add where filter
|
// add where filter
|
||||||
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
|
private <T extends ComparisonOperator> void addWhereFilters(QueryFilter dslQueryFilter,
|
||||||
T comparisonExpression,
|
T comparisonExpression,
|
||||||
Set<QueryFilter> contextMetricFilters,
|
Set<QueryFilter> contextMetricFilters,
|
||||||
List<Expression> addConditions) {
|
List<Expression> addConditions) {
|
||||||
String columnName = dslQueryFilter.getName();
|
String columnName = dslQueryFilter.getName();
|
||||||
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
|
||||||
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
columnName = dslQueryFilter.getFunction() + "(" + dslQueryFilter.getName() + ")";
|
||||||
|
|||||||
Reference in New Issue
Block a user