[improvement][chat]Use accept() pattern to improve code readability.

This commit is contained in:
jerryjzhang
2024-12-19 09:34:38 +08:00
parent 6fcd105249
commit 9faa858c22
12 changed files with 79 additions and 42 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
@@ -8,6 +9,7 @@ import lombok.Data;
@Data
public class ExecuteContext {
private ChatExecuteReq request;
private QueryResult response;
private Agent agent;
private SemanticParseInfo parseInfo;

View File

@@ -43,12 +43,17 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
return Objects.nonNull(chatApp) && chatApp.isEnable();
}
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
Map<String, Object> variable = new HashMap<>();
variable.put("question", executeContext.getRequest().getQueryText());

View File

@@ -27,17 +27,18 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
return QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
&& !CollectionUtils.isEmpty(semanticParseInfo.getMetrics());
}
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
if (!QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;
}
Long dataSetId = semanticParseInfo.getDataSetId();
Optional<SchemaElement> firstMetric = semanticParseInfo.getMetrics().stream().findFirst();
if (!firstMetric.isPresent()) {
return;
}
List<SchemaElement> dimensionRecommended =
getDimensions(firstMetric.get().getId(), dataSetId);
queryResult.setRecommendedDimensions(dimensionRecommended);

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/** A ExecuteResultProcessor wraps things up before returning execution results to the users. */
public interface ExecuteResultProcessor extends ResultProcessor {
void process(ExecuteContext executeContext, QueryResult queryResult);
boolean accept(ExecuteContext executeContext);
void process(ExecuteContext executeContext);
}

View File

@@ -59,14 +59,18 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
public class MetricRatioCalcProcessor implements ExecuteResultProcessor {
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|| !aggregatorConfig.getEnableRatio()
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
return;
}
return !CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
&& aggregatorConfig.getEnableRatio()
&& QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType());
}
@Override
public void process(ExecuteContext executeContext) {
QueryResult queryResult = executeContext.getResponse();
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(),
semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
@@ -16,14 +15,7 @@ import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.*;
import java.util.stream.Collectors;
/**
@@ -34,17 +26,20 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5;
@Override
public void process(ExecuteContext executeContext, QueryResult queryResult) {
public boolean accept(ExecuteContext executeContext) {
SemanticParseInfo parseInfo = executeContext.getParseInfo();
return Objects.nonNull(parseInfo.getQueryType())
&& parseInfo.getQueryType().equals(QueryType.AGGREGATE)
&& !CollectionUtils.isEmpty(parseInfo.getMetrics())
&& parseInfo.getMetrics().size() <= METRIC_RECOMMEND_SIZE;
}
@Override
public void process(ExecuteContext executeContext) {
fillSimilarMetric(executeContext.getParseInfo());
}
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
if (Objects.isNull(parseInfo.getQueryType())
|| !parseInfo.getQueryType().equals(QueryType.AGGREGATE)
|| parseInfo.getMetrics().size() > METRIC_RECOMMEND_SIZE
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
return;
}
List<String> metricNames =
Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, Object> filterCondition = new HashMap<>();

View File

@@ -43,14 +43,17 @@ public class ErrorMsgRewriteProcessor implements ParseResultProcessor {
.enable(false).build());
}
@Override
public boolean accept(ParseContext parseContext) {
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
return StringUtils.isNotBlank(parseContext.getResponse().getErrorMsg())
&& Objects.nonNull(chatApp) || chatApp.isEnable();
}
@Override
public void process(ParseContext parseContext) {
String errMsg = parseContext.getResponse().getErrorMsg();
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
if (StringUtils.isBlank(errMsg) || Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
}
Map<String, Object> variables = new HashMap<>();
variables.put("user_question", parseContext.getRequest().getQueryText());
variables.put("system_message", errMsg);

View File

@@ -28,6 +28,12 @@ import java.util.stream.Collectors;
**/
@Slf4j
public class ParseInfoFormatProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return !parseContext.getResponse().getSelectedParses().isEmpty();
}
@Override
public void process(ParseContext parseContext) {
parseContext.getResponse().getSelectedParses().forEach(p -> {

View File

@@ -6,5 +6,7 @@ import com.tencent.supersonic.chat.server.processor.ResultProcessor;
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
public interface ParseResultProcessor extends ResultProcessor {
boolean accept(ParseContext parseContext);
void process(ParseContext parseContext);
}

View File

@@ -23,6 +23,11 @@ import java.util.stream.Collectors;
@Slf4j
public class QueryRecommendProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return true;
}
@Override
public void process(ParseContext parseContext) {
CompletableFuture.runAsync(() -> doProcess(parseContext));

View File

@@ -10,6 +10,11 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TimeCostCalcProcessor implements ParseResultProcessor {
@Override
public boolean accept(ParseContext parseContext) {
return true;
}
@Override
public void process(ParseContext parseContext) {
ChatParseResp parseResp = parseContext.getResponse();

View File

@@ -95,7 +95,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
chatQueryParsers.forEach(p -> p.parse(parseContext));
parseResultProcessors.forEach(p -> p.process(parseContext));
for (ParseResultProcessor processor : parseResultProcessors) {
if (processor.accept(parseContext)) {
processor.process(parseContext);
}
}
if (!parseContext.needFeedback()) {
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
@@ -116,9 +121,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
}
executeContext.setResponse(queryResult);
if (queryResult != null) {
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(executeContext, queryResult);
if (processor.accept(executeContext)) {
processor.process(executeContext);
}
}
saveQueryResult(chatExecuteReq, queryResult);
}