[improvement][chat]Support user feedback to multiple candidate semantic parses.#1847

This commit is contained in:
jerryjzhang
2024-10-28 17:31:33 +08:00
parent 5dda539798
commit b3d4440781
9 changed files with 51 additions and 34 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.AllArgsConstructor;
import lombok.Builder;
@@ -20,5 +21,5 @@ public class ChatParseReq {
private boolean saveAnswer = true;
private boolean disableLLM = false;
private Long queryId;
private Integer parseId;
private SemanticParseInfo selectedParse;
}

View File

@@ -76,7 +76,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
return;
}
if (parseContext.enableFeedback()) {
if (parseContext.needFeedback()) {
processFeedback(parseContext);
return;
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.Data;
import java.util.Objects;
@@ -14,21 +13,21 @@ public class ParseContext {
private ChatParseReq request;
private ChatParseResp response;
private Agent agent;
private SemanticParseInfo selectedParseInfo;
public ParseContext(ChatParseReq request) {
public ParseContext(ChatParseReq request, ChatParseResp response) {
this.request = request;
this.response = response;
}
public boolean enableNL2SQL() {
return agent.containsDatasetTool();
}
public boolean enableFeedback() {
return agent.enableFeedback() && Objects.isNull(request.getParseId());
public boolean needFeedback() {
return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
}
public boolean enableLLM() {
return !(enableFeedback() || request.isDisableLLM());
return !(needFeedback() || request.isDisableLLM());
}
}

View File

@@ -1,6 +1,11 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import java.util.*;
/**
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
@@ -9,6 +14,17 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext) {
Set<String> parseInfoText = Sets.newHashSet();
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList();
parseContext.getResponse().getSelectedParses().forEach(p -> {
if (!parseInfoText.contains(p.getTextInfo())) {
sortedParseInfo.add(p);
parseInfoText.add(p.getTextInfo());
}
});
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
parseContext.getResponse().setSelectedParses(sortedParseInfo);
}
}

View File

@@ -37,7 +37,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -95,7 +94,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override
public List<SearchResult> search(ChatParseReq chatParseReq) {
ParseContext parseContext = buildParseContext(chatParseReq);
ParseContext parseContext = buildParseContext(chatParseReq, null);
Agent agent = parseContext.getAgent();
if (!agent.enableSearch()) {
return Lists.newArrayList();
@@ -106,20 +105,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override
public ChatParseResp parse(ChatParseReq chatParseReq) {
ParseContext parseContext = buildParseContext(chatParseReq);
Long queryId = chatManageService.createChatQuery(chatParseReq);
parseContext.setResponse(new ChatParseResp(queryId));
for (ChatQueryParser parser : chatQueryParsers) {
parser.parse(parseContext);
}
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(parseContext);
Long queryId = chatParseReq.getQueryId();
if (Objects.isNull(queryId)) {
queryId = chatManageService.createChatQuery(chatParseReq);
}
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
chatQueryParsers.forEach(p -> p.parse(parseContext));
parseResultProcessors.forEach(p -> p.process(parseContext));
if (!parseContext.needFeedback()) {
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
chatManageService.updateParseCostTime(parseContext.getResponse());
}
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
chatManageService.updateParseCostTime(parseContext.getResponse());
return parseContext.getResponse();
}
@@ -164,16 +163,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return execute(executeReq);
}
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
ParseContext parseContext = new ParseContext(chatParseReq);
private ParseContext buildParseContext(ChatParseReq chatParseReq, ChatParseResp chatParseResp) {
ParseContext parseContext = new ParseContext(chatParseReq, chatParseResp);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
parseContext.setAgent(agent);
if (Objects.nonNull(chatParseReq.getQueryId())
&& Objects.nonNull(chatParseReq.getParseId())) {
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatParseReq.getQueryId(),
chatParseReq.getParseId());
parseContext.setSelectedParseInfo(parseInfo);
}
return parseContext;
}

View File

@@ -14,7 +14,7 @@ public class QueryReqConverter {
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getSelectedParseInfo());
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
return queryNLReq;
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.cache;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import java.util.List;
@@ -18,7 +19,10 @@ public class DefaultQueryCache implements QueryCache {
CacheManager cacheManager = ContextUtils.getBean(CacheManager.class);
if (isCache(semanticQueryReq)) {
Object result = cacheManager.get(cacheKey);
log.info("query from cache, key:{},result:{}", cacheKey, result);
if (Objects.nonNull(result)) {
log.info("query from cache, key:{},result:{}", cacheKey,
StringUtils.normalizeSpace(result.toString()));
}
return result;
}
return null;

View File

@@ -131,7 +131,10 @@ public class S2SemanticLayerService implements SemanticLayerService {
String cacheKey = queryCache.getCacheKey(queryReq);
Object query = queryCache.query(queryReq, cacheKey);
log.info("cacheKey:{},query:{}", cacheKey, query);
if (Objects.nonNull(query)) {
log.info("cacheKey:{},query:{}", cacheKey,
StringUtils.normalizeSpace(query.toString()));
}
if (Objects.nonNull(query)) {
SemanticQueryResp queryResp = (SemanticQueryResp) query;
queryResp.setUseCache(true);

View File

@@ -69,7 +69,8 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\