mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][chat]Use accept() pattern to improve code readability.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
@@ -8,6 +9,7 @@ import lombok.Data;
|
||||
@Data
|
||||
public class ExecuteContext {
|
||||
private ChatExecuteReq request;
|
||||
private QueryResult response;
|
||||
private Agent agent;
|
||||
private SemanticParseInfo parseInfo;
|
||||
|
||||
|
||||
@@ -43,12 +43,17 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
Agent agent = executeContext.getAgent();
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
return Objects.nonNull(chatApp) && chatApp.isEnable();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext) {
|
||||
QueryResult queryResult = executeContext.getResponse();
|
||||
Agent agent = executeContext.getAgent();
|
||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("question", executeContext.getRequest().getQueryText());
|
||||
|
||||
@@ -27,17 +27,18 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int recommend_dimension_size = 5;
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
return QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
|
||||
&& !CollectionUtils.isEmpty(semanticParseInfo.getMetrics());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext) {
|
||||
QueryResult queryResult = executeContext.getResponse();
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
if (!QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
|
||||
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
Optional<SchemaElement> firstMetric = semanticParseInfo.getMetrics().stream().findFirst();
|
||||
if (!firstMetric.isPresent()) {
|
||||
return;
|
||||
}
|
||||
List<SchemaElement> dimensionRecommended =
|
||||
getDimensions(firstMetric.get().getId(), dataSetId);
|
||||
queryResult.setRecommendedDimensions(dimensionRecommended);
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
|
||||
/** A ExecuteResultProcessor wraps things up before returning execution results to the users. */
|
||||
public interface ExecuteResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ExecuteContext executeContext, QueryResult queryResult);
|
||||
boolean accept(ExecuteContext executeContext);
|
||||
|
||||
void process(ExecuteContext executeContext);
|
||||
}
|
||||
|
||||
@@ -59,14 +59,18 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
public class MetricRatioCalcProcessor implements ExecuteResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
|
||||
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
||||
|| !aggregatorConfig.getEnableRatio()
|
||||
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
|
||||
return;
|
||||
}
|
||||
return !CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
||||
&& aggregatorConfig.getEnableRatio()
|
||||
&& QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext) {
|
||||
QueryResult queryResult = executeContext.getResponse();
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(),
|
||||
semanticParseInfo, queryResult);
|
||||
queryResult.setAggregateInfo(aggregateInfo);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
@@ -16,14 +15,7 @@ import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
@@ -34,17 +26,20 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
return Objects.nonNull(parseInfo.getQueryType())
|
||||
&& parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|
||||
&& !CollectionUtils.isEmpty(parseInfo.getMetrics())
|
||||
&& parseInfo.getMetrics().size() <= METRIC_RECOMMEND_SIZE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ExecuteContext executeContext) {
|
||||
fillSimilarMetric(executeContext.getParseInfo());
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
if (Objects.isNull(parseInfo.getQueryType())
|
||||
|| !parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|
||||
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
List<String> metricNames =
|
||||
Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
|
||||
@@ -43,14 +43,17 @@ public class ErrorMsgRewriteProcessor implements ParseResultProcessor {
|
||||
.enable(false).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean accept(ParseContext parseContext) {
|
||||
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
|
||||
return StringUtils.isNotBlank(parseContext.getResponse().getErrorMsg())
|
||||
&& Objects.nonNull(chatApp) || chatApp.isEnable();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
String errMsg = parseContext.getResponse().getErrorMsg();
|
||||
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
|
||||
if (StringUtils.isBlank(errMsg) || Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("user_question", parseContext.getRequest().getQueryText());
|
||||
variables.put("system_message", errMsg);
|
||||
|
||||
@@ -28,6 +28,12 @@ import java.util.stream.Collectors;
|
||||
**/
|
||||
@Slf4j
|
||||
public class ParseInfoFormatProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public boolean accept(ParseContext parseContext) {
|
||||
return !parseContext.getResponse().getSelectedParses().isEmpty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||
|
||||
@@ -6,5 +6,7 @@ import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
|
||||
public interface ParseResultProcessor extends ResultProcessor {
|
||||
|
||||
boolean accept(ParseContext parseContext);
|
||||
|
||||
void process(ParseContext parseContext);
|
||||
}
|
||||
|
||||
@@ -23,6 +23,11 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public boolean accept(ParseContext parseContext) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseContext));
|
||||
|
||||
@@ -10,6 +10,11 @@ import lombok.extern.slf4j.Slf4j;
|
||||
@Slf4j
|
||||
public class TimeCostCalcProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public boolean accept(ParseContext parseContext) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
|
||||
@@ -95,7 +95,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
||||
chatQueryParsers.forEach(p -> p.parse(parseContext));
|
||||
parseResultProcessors.forEach(p -> p.process(parseContext));
|
||||
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
if (processor.accept(parseContext)) {
|
||||
processor.process(parseContext);
|
||||
}
|
||||
}
|
||||
|
||||
if (!parseContext.needFeedback()) {
|
||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||
@@ -116,9 +121,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
}
|
||||
|
||||
executeContext.setResponse(queryResult);
|
||||
if (queryResult != null) {
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(executeContext, queryResult);
|
||||
if (processor.accept(executeContext)) {
|
||||
processor.process(executeContext);
|
||||
}
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user