[improvement][chat]Support user feedback to semantic parse info.#1729

This commit is contained in:
jerryjzhang
2024-10-28 02:07:54 +08:00
parent c2785139f2
commit eb28d832bc
11 changed files with 85 additions and 54 deletions

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
@@ -17,10 +16,12 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.utils.ModelConfigHelper;
import dev.langchain4j.data.message.AiMessage;
@@ -74,37 +75,48 @@ public class NL2SQLParser implements ChatQueryParser {
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
return;
}
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
if (Objects.isNull(queryNLReq)) {
if (parseContext.enableFeedback()) {
processFeedback(parseContext);
return;
}
ParseResp parseResp = parseContext.getResponse();
ChatParseReq parseReq = parseContext.getRequest();
if (!parseContext.getRequest().isDisableLLM() && queryNLReq.getText2SQLType().enableLLM()) {
processMultiTurn(parseContext);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
parseResp.setUsedExemplars(queryNLReq.getDynamicExemplars());
}
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseReq.getChatId());
if (chatCtx != null) {
ChatContext chatCtx =
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
ParseResp text2SqlParseResp = chatLayerService.parse(queryNLReq);
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
if (parseContext.enableLLM()) {
rewriteMultiTurn(parseContext, queryNLReq);
addDynamicExemplars(parseContext, queryNLReq);
}
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
ParseResp parseResp = parseContext.getResponse();
doParse(queryNLReq, parseResp);
}
private void processMultiTurn(ParseContext parseContext) {
private void processFeedback(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
ParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
}
}
private void doParse(QueryNLReq req, ParseResp resp) {
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
ParseResp text2SqlParseResp = chatLayerService.parse(req);
if (text2SqlParseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
resp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
}
resp.setState(text2SqlParseResp.getState());
resp.setParseTimeCost(text2SqlParseResp.getParseTimeCost());
resp.setErrorMsg(text2SqlParseResp.getErrorMsg());
}
private void rewriteMultiTurn(ParseContext parseContext, QueryNLReq queryNLReq) {
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_MULTI_TURN);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
return;
@@ -112,7 +124,6 @@ public class NL2SQLParser implements ChatQueryParser {
// derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
MapResp currentMapResult = chatLayerService.map(queryNLReq);
List<QueryResp> historyQueries =
@@ -143,6 +154,7 @@ public class NL2SQLParser implements ChatQueryParser {
String rewrittenQuery = response.content().text();
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
parseContext.getRequest().setQueryText(rewrittenQuery);
queryNLReq.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
currentMapResult.getQueryText(), rewrittenQuery);
}
@@ -185,15 +197,17 @@ public class NL2SQLParser implements ChatQueryParser {
return contextualList;
}
private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) {
private void addDynamicExemplars(ParseContext parseContext, QueryNLReq queryNLReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
String memoryCollectionName =
embeddingConfig.getMemoryCollectionName(parseContext.getAgent().getId());
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int exemplarRecallNumber =
Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryNLReq.getQueryText(), exemplarRecallNumber);
queryNLReq.getDynamicExemplars().addAll(exemplars);
parseContext.getResponse().setUsedExemplars(exemplars);
}
}

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.common.config.ParameterConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service("ChatQueryParserConfig")
@Slf4j
public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_MULTI_TURN_ENABLE =
new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token",
"bool", "语义解析配置");
}

View File

@@ -2,14 +2,18 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.Data;
import java.util.Objects;
@Data
public class ParseContext {
private ChatParseReq request;
private ParseResp response;
private Agent agent;
private SemanticParseInfo selectedParseInfo;
public ParseContext(ChatParseReq request) {
this.request = request;
@@ -17,9 +21,14 @@ public class ParseContext {
}
public boolean enableNL2SQL() {
if (agent == null) {
return false;
}
return agent.containsDatasetTool();
}
public boolean enableFeedback() {
return agent.enableFeedback() && Objects.isNull(request.getParseId());
}
public boolean enableLLM() {
return !(enableFeedback() || request.isDisableLLM());
}
}

View File

@@ -168,6 +168,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseContext parseContext = new ParseContext(chatParseReq);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
parseContext.setAgent(agent);
if (Objects.nonNull(chatParseReq.getQueryId())
&& Objects.nonNull(chatParseReq.getParseId())) {
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatParseReq.getQueryId(),
chatParseReq.getParseId());
parseContext.setSelectedParseInfo(parseInfo);
}
return parseContext;
}

View File

@@ -8,16 +8,13 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
public class QueryReqConverter {
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
if (parseContext.getAgent() == null) {
return null;
}
QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
queryNLReq.setText2SQLType(parseContext.getRequest().isDisableLLM() ? Text2SQLType.ONLY_RULE
: Text2SQLType.RULE_AND_LLM);
queryNLReq.setText2SQLType(
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
queryNLReq.setSelectedParseInfo(parseContext.getSelectedParseInfo());
return queryNLReq;
}