mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat]Support user feedback to multiple candidate semantic parses.#1847
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
@@ -20,5 +21,5 @@ public class ChatParseReq {
|
||||
private boolean saveAnswer = true;
|
||||
private boolean disableLLM = false;
|
||||
private Long queryId;
|
||||
private Integer parseId;
|
||||
private SemanticParseInfo selectedParse;
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
||||
return;
|
||||
}
|
||||
if (parseContext.enableFeedback()) {
|
||||
if (parseContext.needFeedback()) {
|
||||
processFeedback(parseContext);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.Objects;
|
||||
@@ -14,21 +13,21 @@ public class ParseContext {
|
||||
private ChatParseReq request;
|
||||
private ChatParseResp response;
|
||||
private Agent agent;
|
||||
private SemanticParseInfo selectedParseInfo;
|
||||
|
||||
public ParseContext(ChatParseReq request) {
|
||||
public ParseContext(ChatParseReq request, ChatParseResp response) {
|
||||
this.request = request;
|
||||
this.response = response;
|
||||
}
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
return agent.containsDatasetTool();
|
||||
}
|
||||
|
||||
public boolean enableFeedback() {
|
||||
return agent.enableFeedback() && Objects.isNull(request.getParseId());
|
||||
public boolean needFeedback() {
|
||||
return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
|
||||
}
|
||||
|
||||
public boolean enableLLM() {
|
||||
return !(enableFeedback() || request.isDisableLLM());
|
||||
return !(needFeedback() || request.isDisableLLM());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
|
||||
@@ -9,6 +14,17 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
Set<String> parseInfoText = Sets.newHashSet();
|
||||
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList();
|
||||
|
||||
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||
if (!parseInfoText.contains(p.getTextInfo())) {
|
||||
sortedParseInfo.add(p);
|
||||
parseInfoText.add(p.getTextInfo());
|
||||
}
|
||||
});
|
||||
|
||||
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
|
||||
parseContext.getResponse().setSelectedParses(sortedParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
@@ -95,7 +94,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
ParseContext parseContext = buildParseContext(chatParseReq, null);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
@@ -106,20 +105,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Override
|
||||
public ChatParseResp parse(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
Long queryId = chatManageService.createChatQuery(chatParseReq);
|
||||
parseContext.setResponse(new ChatParseResp(queryId));
|
||||
|
||||
for (ChatQueryParser parser : chatQueryParsers) {
|
||||
parser.parse(parseContext);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(parseContext);
|
||||
Long queryId = chatParseReq.getQueryId();
|
||||
if (Objects.isNull(queryId)) {
|
||||
queryId = chatManageService.createChatQuery(chatParseReq);
|
||||
}
|
||||
|
||||
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
||||
chatQueryParsers.forEach(p -> p.parse(parseContext));
|
||||
parseResultProcessors.forEach(p -> p.process(parseContext));
|
||||
|
||||
if (!parseContext.needFeedback()) {
|
||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||
}
|
||||
|
||||
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||
return parseContext.getResponse();
|
||||
}
|
||||
|
||||
@@ -164,16 +163,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return execute(executeReq);
|
||||
}
|
||||
|
||||
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = new ParseContext(chatParseReq);
|
||||
private ParseContext buildParseContext(ChatParseReq chatParseReq, ChatParseResp chatParseResp) {
|
||||
ParseContext parseContext = new ParseContext(chatParseReq, chatParseResp);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
parseContext.setAgent(agent);
|
||||
if (Objects.nonNull(chatParseReq.getQueryId())
|
||||
&& Objects.nonNull(chatParseReq.getParseId())) {
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatParseReq.getQueryId(),
|
||||
chatParseReq.getParseId());
|
||||
parseContext.setSelectedParseInfo(parseInfo);
|
||||
}
|
||||
return parseContext;
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ public class QueryReqConverter {
|
||||
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
|
||||
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
|
||||
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
||||
queryNLReq.setSelectedParseInfo(parseContext.getSelectedParseInfo());
|
||||
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
|
||||
|
||||
return queryNLReq;
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.cache;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
@@ -18,7 +19,10 @@ public class DefaultQueryCache implements QueryCache {
|
||||
CacheManager cacheManager = ContextUtils.getBean(CacheManager.class);
|
||||
if (isCache(semanticQueryReq)) {
|
||||
Object result = cacheManager.get(cacheKey);
|
||||
log.info("query from cache, key:{},result:{}", cacheKey, result);
|
||||
if (Objects.nonNull(result)) {
|
||||
log.info("query from cache, key:{},result:{}", cacheKey,
|
||||
StringUtils.normalizeSpace(result.toString()));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
return null;
|
||||
|
||||
@@ -131,7 +131,10 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
||||
|
||||
String cacheKey = queryCache.getCacheKey(queryReq);
|
||||
Object query = queryCache.query(queryReq, cacheKey);
|
||||
log.info("cacheKey:{},query:{}", cacheKey, query);
|
||||
if (Objects.nonNull(query)) {
|
||||
log.info("cacheKey:{},query:{}", cacheKey,
|
||||
StringUtils.normalizeSpace(query.toString()));
|
||||
}
|
||||
if (Objects.nonNull(query)) {
|
||||
SemanticQueryResp queryResp = (SemanticQueryResp) query;
|
||||
queryResp.setUseCache(true);
|
||||
|
||||
@@ -69,7 +69,8 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
|
||||
|
||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||
|
||||
Reference in New Issue
Block a user