[improvement][chat]Support user feedback to multiple candidate semantic parses.#1847

This commit is contained in:
jerryjzhang
2024-10-28 17:31:33 +08:00
parent 5dda539798
commit b3d4440781
9 changed files with 51 additions and 34 deletions

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
@@ -20,5 +21,5 @@ public class ChatParseReq {
private boolean saveAnswer = true; private boolean saveAnswer = true;
private boolean disableLLM = false; private boolean disableLLM = false;
private Long queryId; private Long queryId;
private Integer parseId; private SemanticParseInfo selectedParse;
} }

View File

@@ -76,7 +76,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) { if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
return; return;
} }
if (parseContext.enableFeedback()) { if (parseContext.needFeedback()) {
processFeedback(parseContext); processFeedback(parseContext);
return; return;
} }

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp; import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.Data; import lombok.Data;
import java.util.Objects; import java.util.Objects;
@@ -14,21 +13,21 @@ public class ParseContext {
private ChatParseReq request; private ChatParseReq request;
private ChatParseResp response; private ChatParseResp response;
private Agent agent; private Agent agent;
private SemanticParseInfo selectedParseInfo;
public ParseContext(ChatParseReq request) { public ParseContext(ChatParseReq request, ChatParseResp response) {
this.request = request; this.request = request;
this.response = response;
} }
public boolean enableNL2SQL() { public boolean enableNL2SQL() {
return agent.containsDatasetTool(); return agent.containsDatasetTool();
} }
public boolean enableFeedback() { public boolean needFeedback() {
return agent.enableFeedback() && Objects.isNull(request.getParseId()); return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
} }
public boolean enableLLM() { public boolean enableLLM() {
return !(enableFeedback() || request.isDisableLLM()); return !(needFeedback() || request.isDisableLLM());
} }
} }

View File

@@ -1,6 +1,11 @@
package com.tencent.supersonic.chat.server.processor.parse; package com.tencent.supersonic.chat.server.processor.parse;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import java.util.*;
/** /**
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \ * ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
@@ -9,6 +14,17 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
@Override @Override
public void process(ParseContext parseContext) { public void process(ParseContext parseContext) {
Set<String> parseInfoText = Sets.newHashSet();
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList();
parseContext.getResponse().getSelectedParses().forEach(p -> {
if (!parseInfoText.contains(p.getTextInfo())) {
sortedParseInfo.add(p);
parseInfoText.add(p.getTextInfo());
}
});
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
parseContext.getResponse().setSelectedParses(sortedParseInfo);
} }
} }

View File

@@ -37,7 +37,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -95,7 +94,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override @Override
public List<SearchResult> search(ChatParseReq chatParseReq) { public List<SearchResult> search(ChatParseReq chatParseReq) {
ParseContext parseContext = buildParseContext(chatParseReq); ParseContext parseContext = buildParseContext(chatParseReq, null);
Agent agent = parseContext.getAgent(); Agent agent = parseContext.getAgent();
if (!agent.enableSearch()) { if (!agent.enableSearch()) {
return Lists.newArrayList(); return Lists.newArrayList();
@@ -106,20 +105,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override @Override
public ChatParseResp parse(ChatParseReq chatParseReq) { public ChatParseResp parse(ChatParseReq chatParseReq) {
ParseContext parseContext = buildParseContext(chatParseReq); Long queryId = chatParseReq.getQueryId();
Long queryId = chatManageService.createChatQuery(chatParseReq); if (Objects.isNull(queryId)) {
parseContext.setResponse(new ChatParseResp(queryId)); queryId = chatManageService.createChatQuery(chatParseReq);
}
for (ChatQueryParser parser : chatQueryParsers) {
parser.parse(parseContext); ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
} chatQueryParsers.forEach(p -> p.parse(parseContext));
for (ParseResultProcessor processor : parseResultProcessors) { parseResultProcessors.forEach(p -> p.process(parseContext));
processor.process(parseContext);
if (!parseContext.needFeedback()) {
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
chatManageService.updateParseCostTime(parseContext.getResponse());
} }
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
chatManageService.updateParseCostTime(parseContext.getResponse());
return parseContext.getResponse(); return parseContext.getResponse();
} }
@@ -164,16 +163,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return execute(executeReq); return execute(executeReq);
} }
private ParseContext buildParseContext(ChatParseReq chatParseReq) { private ParseContext buildParseContext(ChatParseReq chatParseReq, ChatParseResp chatParseResp) {
ParseContext parseContext = new ParseContext(chatParseReq); ParseContext parseContext = new ParseContext(chatParseReq, chatParseResp);
Agent agent = agentService.getAgent(chatParseReq.getAgentId()); Agent agent = agentService.getAgent(chatParseReq.getAgentId());
parseContext.setAgent(agent); parseContext.setAgent(agent);
if (Objects.nonNull(chatParseReq.getQueryId())
&& Objects.nonNull(chatParseReq.getParseId())) {
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatParseReq.getQueryId(),
chatParseReq.getParseId());
parseContext.setSelectedParseInfo(parseInfo);
}
return parseContext; return parseContext;
} }

View File

@@ -14,7 +14,7 @@ public class QueryReqConverter {
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE); parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds()); queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getSelectedParseInfo()); queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
return queryNLReq; return queryNLReq;
} }

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.cache;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.List; import java.util.List;
@@ -18,7 +19,10 @@ public class DefaultQueryCache implements QueryCache {
CacheManager cacheManager = ContextUtils.getBean(CacheManager.class); CacheManager cacheManager = ContextUtils.getBean(CacheManager.class);
if (isCache(semanticQueryReq)) { if (isCache(semanticQueryReq)) {
Object result = cacheManager.get(cacheKey); Object result = cacheManager.get(cacheKey);
log.info("query from cache, key:{},result:{}", cacheKey, result); if (Objects.nonNull(result)) {
log.info("query from cache, key:{},result:{}", cacheKey,
StringUtils.normalizeSpace(result.toString()));
}
return result; return result;
} }
return null; return null;

View File

@@ -131,7 +131,10 @@ public class S2SemanticLayerService implements SemanticLayerService {
String cacheKey = queryCache.getCacheKey(queryReq); String cacheKey = queryCache.getCacheKey(queryReq);
Object query = queryCache.query(queryReq, cacheKey); Object query = queryCache.query(queryReq, cacheKey);
log.info("cacheKey:{},query:{}", cacheKey, query); if (Objects.nonNull(query)) {
log.info("cacheKey:{},query:{}", cacheKey,
StringUtils.normalizeSpace(query.toString()));
}
if (Objects.nonNull(query)) { if (Objects.nonNull(query)) {
SemanticQueryResp queryResp = (SemanticQueryResp) query; SemanticQueryResp queryResp = (SemanticQueryResp) query;
queryResp.setUseCache(true); queryResp.setUseCache(true);

View File

@@ -69,7 +69,8 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\ com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\ com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\