mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement][chat]Support user feedback to multiple candidate semantic parses.#1847
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
@@ -20,5 +21,5 @@ public class ChatParseReq {
|
|||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private boolean disableLLM = false;
|
private boolean disableLLM = false;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private Integer parseId;
|
private SemanticParseInfo selectedParse;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (parseContext.enableFeedback()) {
|
if (parseContext.needFeedback()) {
|
||||||
processFeedback(parseContext);
|
processFeedback(parseContext);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -14,21 +13,21 @@ public class ParseContext {
|
|||||||
private ChatParseReq request;
|
private ChatParseReq request;
|
||||||
private ChatParseResp response;
|
private ChatParseResp response;
|
||||||
private Agent agent;
|
private Agent agent;
|
||||||
private SemanticParseInfo selectedParseInfo;
|
|
||||||
|
|
||||||
public ParseContext(ChatParseReq request) {
|
public ParseContext(ChatParseReq request, ChatParseResp response) {
|
||||||
this.request = request;
|
this.request = request;
|
||||||
|
this.response = response;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableNL2SQL() {
|
public boolean enableNL2SQL() {
|
||||||
return agent.containsDatasetTool();
|
return agent.containsDatasetTool();
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableFeedback() {
|
public boolean needFeedback() {
|
||||||
return agent.enableFeedback() && Objects.isNull(request.getParseId());
|
return agent.enableFeedback() && Objects.isNull(request.getSelectedParse());
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableLLM() {
|
public boolean enableLLM() {
|
||||||
return !(enableFeedback() || request.isDisableLLM());
|
return !(needFeedback() || request.isDisableLLM());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.parse;
|
package com.tencent.supersonic.chat.server.processor.parse;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
|
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
|
||||||
@@ -9,6 +14,17 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ParseContext parseContext) {
|
public void process(ParseContext parseContext) {
|
||||||
|
Set<String> parseInfoText = Sets.newHashSet();
|
||||||
|
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList();
|
||||||
|
|
||||||
|
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||||
|
if (!parseInfoText.contains(p.getTextInfo())) {
|
||||||
|
sortedParseInfo.add(p);
|
||||||
|
parseInfoText.add(p.getTextInfo());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Collections.sort(sortedParseInfo, (o1, o2) -> o1.getScore() - o2.getScore() >= 0 ? 1 : 0);
|
||||||
|
parseContext.getResponse().setSelectedParses(sortedParseInfo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
@@ -95,7 +94,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
ParseContext parseContext = buildParseContext(chatParseReq, null);
|
||||||
Agent agent = parseContext.getAgent();
|
Agent agent = parseContext.getAgent();
|
||||||
if (!agent.enableSearch()) {
|
if (!agent.enableSearch()) {
|
||||||
return Lists.newArrayList();
|
return Lists.newArrayList();
|
||||||
@@ -106,20 +105,20 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatParseResp parse(ChatParseReq chatParseReq) {
|
public ChatParseResp parse(ChatParseReq chatParseReq) {
|
||||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
Long queryId = chatParseReq.getQueryId();
|
||||||
Long queryId = chatManageService.createChatQuery(chatParseReq);
|
if (Objects.isNull(queryId)) {
|
||||||
parseContext.setResponse(new ChatParseResp(queryId));
|
queryId = chatManageService.createChatQuery(chatParseReq);
|
||||||
|
}
|
||||||
for (ChatQueryParser parser : chatQueryParsers) {
|
|
||||||
parser.parse(parseContext);
|
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
||||||
}
|
chatQueryParsers.forEach(p -> p.parse(parseContext));
|
||||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
parseResultProcessors.forEach(p -> p.process(parseContext));
|
||||||
processor.process(parseContext);
|
|
||||||
|
if (!parseContext.needFeedback()) {
|
||||||
|
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||||
|
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||||
}
|
}
|
||||||
|
|
||||||
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
|
||||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
|
||||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
|
||||||
return parseContext.getResponse();
|
return parseContext.getResponse();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,16 +163,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
return execute(executeReq);
|
return execute(executeReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
private ParseContext buildParseContext(ChatParseReq chatParseReq, ChatParseResp chatParseResp) {
|
||||||
ParseContext parseContext = new ParseContext(chatParseReq);
|
ParseContext parseContext = new ParseContext(chatParseReq, chatParseResp);
|
||||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||||
parseContext.setAgent(agent);
|
parseContext.setAgent(agent);
|
||||||
if (Objects.nonNull(chatParseReq.getQueryId())
|
|
||||||
&& Objects.nonNull(chatParseReq.getParseId())) {
|
|
||||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatParseReq.getQueryId(),
|
|
||||||
chatParseReq.getParseId());
|
|
||||||
parseContext.setSelectedParseInfo(parseInfo);
|
|
||||||
}
|
|
||||||
return parseContext;
|
return parseContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ public class QueryReqConverter {
|
|||||||
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
|
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
|
||||||
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
|
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
|
||||||
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
||||||
queryNLReq.setSelectedParseInfo(parseContext.getSelectedParseInfo());
|
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
|
||||||
|
|
||||||
return queryNLReq;
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.cache;
|
|||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -18,7 +19,10 @@ public class DefaultQueryCache implements QueryCache {
|
|||||||
CacheManager cacheManager = ContextUtils.getBean(CacheManager.class);
|
CacheManager cacheManager = ContextUtils.getBean(CacheManager.class);
|
||||||
if (isCache(semanticQueryReq)) {
|
if (isCache(semanticQueryReq)) {
|
||||||
Object result = cacheManager.get(cacheKey);
|
Object result = cacheManager.get(cacheKey);
|
||||||
log.info("query from cache, key:{},result:{}", cacheKey, result);
|
if (Objects.nonNull(result)) {
|
||||||
|
log.info("query from cache, key:{},result:{}", cacheKey,
|
||||||
|
StringUtils.normalizeSpace(result.toString()));
|
||||||
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
|
|||||||
@@ -131,7 +131,10 @@ public class S2SemanticLayerService implements SemanticLayerService {
|
|||||||
|
|
||||||
String cacheKey = queryCache.getCacheKey(queryReq);
|
String cacheKey = queryCache.getCacheKey(queryReq);
|
||||||
Object query = queryCache.query(queryReq, cacheKey);
|
Object query = queryCache.query(queryReq, cacheKey);
|
||||||
log.info("cacheKey:{},query:{}", cacheKey, query);
|
if (Objects.nonNull(query)) {
|
||||||
|
log.info("cacheKey:{},query:{}", cacheKey,
|
||||||
|
StringUtils.normalizeSpace(query.toString()));
|
||||||
|
}
|
||||||
if (Objects.nonNull(query)) {
|
if (Objects.nonNull(query)) {
|
||||||
SemanticQueryResp queryResp = (SemanticQueryResp) query;
|
SemanticQueryResp queryResp = (SemanticQueryResp) query;
|
||||||
queryResp.setUseCache(true);
|
queryResp.setUseCache(true);
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
|||||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.TimeCostCalcProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.ErrorMsgRewriteProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor
|
com.tencent.supersonic.chat.server.processor.parse.ParseInfoFormatProcessor,\
|
||||||
|
com.tencent.supersonic.chat.server.processor.parse.ParseInfoSortProcessor
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||||
|
|||||||
Reference in New Issue
Block a user