mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(chat)Implement a new version of multi-turn conversation.
This commit is contained in:
@@ -0,0 +1,146 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.core.env.Environment;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.Collections;
|
||||
|
||||
@Slf4j
|
||||
public class MultiTurnParser implements ChatParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger(MultiTurnParser.class);
|
||||
|
||||
private static final PromptTemplate promptTemplate = PromptTemplate.from(
|
||||
"You are a data product manager experienced in data requirements."
|
||||
+ "Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value), "
|
||||
+ "please try understanding the semantics and rewrite a question"
|
||||
+ "(keep relevant metrics, dimensions, values and date ranges)."
|
||||
+ "Current Question: {{curtQuestion}} "
|
||||
+ "Current Mapped Schema: {{curtSchema}} "
|
||||
+ "History Question: {{histQuestion}} "
|
||||
+ "History Mapped Schema: {{histSchema}} "
|
||||
+ "Rewritten Question: ");
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class);
|
||||
if (Boolean.FALSE.equals(multiTurn)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp currentMapResult = chatQueryService.performMapping(queryReq);
|
||||
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
|
||||
if (historyParseResults.size() == 0) {
|
||||
return;
|
||||
}
|
||||
ParseResp lastParseResult = historyParseResults.get(0);
|
||||
Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId();
|
||||
|
||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||
.curtQuestion(currentMapResult.getQueryText())
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.build());
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
private String rewriteQuery(RewriteContext context) {
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("curtQuestion", context.getCurtQuestion());
|
||||
variables.put("histQuestion", context.getHistQuestion());
|
||||
variables.put("curtSchema", context.getCurtSchema());
|
||||
variables.put("histSchema", context.getHistSchema());
|
||||
|
||||
Prompt prompt = promptTemplate.apply(variables);
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
//3.format response.
|
||||
String rewriteQuery = response.content().text();
|
||||
|
||||
return rewriteQuery;
|
||||
}
|
||||
|
||||
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
||||
List<String> metrics = new ArrayList<>();
|
||||
List<String> dimensions = new ArrayList<>();
|
||||
List<String> values = new ArrayList<>();
|
||||
|
||||
for (SchemaElementMatch match : elementMatches) {
|
||||
if (match.getElement().getType().equals(SchemaElementType.METRIC)) {
|
||||
metrics.add(match.getWord());
|
||||
} else if (match.getElement().getType().equals(SchemaElementType.DIMENSION)) {
|
||||
dimensions.add(match.getWord());
|
||||
} else if (match.getElement().getType().equals(SchemaElementType.VALUE)) {
|
||||
values.add(match.getWord());
|
||||
}
|
||||
}
|
||||
|
||||
StringBuilder prompt = new StringBuilder();
|
||||
prompt.append(String.format("'metrics:':[%s]", String.join(",", metrics)));
|
||||
prompt.append(",");
|
||||
prompt.append(String.format("'dimensions:':[%s]", String.join(",", dimensions)));
|
||||
prompt.append(",");
|
||||
prompt.append(String.format("'values:':[%s]", String.join(",", values)));
|
||||
|
||||
return prompt.toString();
|
||||
}
|
||||
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
return contextualList;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public static class RewriteContext {
|
||||
private String curtQuestion;
|
||||
private String histQuestion;
|
||||
private String curtSchema;
|
||||
private String histSchema;
|
||||
}
|
||||
}
|
||||
@@ -1,47 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMRequestService;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.RewriteQueryGeneration;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Collections;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.server.service.impl.ChatQueryServiceImpl;
|
||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.env.Environment;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.CONTEXT;
|
||||
|
||||
@Slf4j
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
|
||||
private int contextualNum = 5;
|
||||
|
||||
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL()) {
|
||||
@@ -50,15 +23,8 @@ public class NL2SQLParser implements ChatParser {
|
||||
if (checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
considerMultiturn(chatParseContext, parseResp);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String multiTurn = environment.getProperty("multi.turn");
|
||||
if (StringUtils.isNotBlank(multiTurn) && Boolean.parseBoolean(multiTurn)) {
|
||||
queryReq.setMapInfo(new SchemaMapInfo());
|
||||
}
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
@@ -67,100 +33,6 @@ public class NL2SQLParser implements ChatParser {
|
||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||
}
|
||||
|
||||
private void considerMultiturn(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
RewriteQueryGeneration rewriteQueryGeneration = ContextUtils.getBean(RewriteQueryGeneration.class);
|
||||
String multiTurn = environment.getProperty("multi.turn");
|
||||
String multiNum = environment.getProperty("multi.num");
|
||||
if (StringUtils.isBlank(multiTurn) || !Boolean.parseBoolean(multiTurn)) {
|
||||
return;
|
||||
}
|
||||
log.info("multi turn text-to-sql!");
|
||||
List<ParseResp> contextualList = getContextualList(parseResp, multiNum);
|
||||
List<String> contextualQuestions = getContextualQuestionsWithLink(contextualList);
|
||||
StringBuffer currentPromptSb = new StringBuffer();
|
||||
if (contextualQuestions.size() == 0) {
|
||||
currentPromptSb.append("contextualQuestions:" + "\n");
|
||||
} else {
|
||||
currentPromptSb.append("contextualQuestions:" + "\n" + String.join("\n", contextualQuestions) + "\n");
|
||||
}
|
||||
String currentQuestion = getQueryLinks(chatParseContext);
|
||||
currentPromptSb.append("currentQuestion:" + currentQuestion + "\n");
|
||||
currentPromptSb.append("rewritingCurrentQuestion:\n");
|
||||
String rewriteQuery = rewriteQueryGeneration.generation(currentPromptSb.toString());
|
||||
log.info("rewriteQuery:{}", rewriteQuery);
|
||||
chatParseContext.setQueryText(rewriteQuery);
|
||||
}
|
||||
|
||||
private List<String> getContextualQuestionsWithLink(List<ParseResp> contextualList) {
|
||||
List<String> contextualQuestions = new ArrayList<>();
|
||||
contextualList.stream().forEach(o -> {
|
||||
Map<String, Object> map = JsonUtil.toMap(JsonUtil.toString(
|
||||
o.getSelectedParses().get(0).getProperties().get(CONTEXT)), String.class, Object.class);
|
||||
LLMReq llmReq = JsonUtil.toObject(JsonUtil.toString(map.get("llmReq")), LLMReq.class);
|
||||
List<LLMReq.ElementValue> linking = llmReq.getLinking();
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (LLMReq.ElementValue priorLinking : linking) {
|
||||
String fieldName = priorLinking.getFieldName();
|
||||
String fieldValue = priorLinking.getFieldValue();
|
||||
priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
|
||||
}
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String questionAugmented = String.format("%s (补充信息:%s) ", o.getQueryText(), linkingListStr);
|
||||
contextualQuestions.add(questionAugmented);
|
||||
});
|
||||
return contextualQuestions;
|
||||
}
|
||||
|
||||
private List<ParseResp> getContextualList(ParseResp parseResp, String multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(
|
||||
parseResp.getChatId()).stream().filter(o -> o.getSelectedParses().get(0)
|
||||
.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)
|
||||
).collect(Collectors.toList());
|
||||
if (StringUtils.isNotBlank(multiNum) && StringUtils.isNumeric(multiNum)) {
|
||||
int num = Integer.parseInt(multiNum);
|
||||
contextualNum = Math.min(contextualNum, num);
|
||||
}
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(contextualNum, contextualParseInfoList.size()));
|
||||
Collections.reverse(contextualList);
|
||||
return contextualList;
|
||||
}
|
||||
|
||||
private String getQueryLinks(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
|
||||
ChatQueryServiceImpl chatQueryService = ContextUtils.getBean(ChatQueryServiceImpl.class);
|
||||
// build queryContext and chatContext
|
||||
QueryContext queryCtx = chatQueryService.buildQueryContext(queryReq);
|
||||
|
||||
// 1. mapper
|
||||
if (Objects.isNull(chatParseContext.getMapInfo())
|
||||
|| MapUtils.isEmpty(chatParseContext.getMapInfo().getDataSetElementMatches())) {
|
||||
schemaMappers.forEach(mapper -> {
|
||||
mapper.map(queryCtx);
|
||||
});
|
||||
}
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
log.info("dataSetId:{}", dataSetId);
|
||||
if (dataSetId == null) {
|
||||
return null;
|
||||
}
|
||||
List<LLMReq.ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (LLMReq.ElementValue priorLinking : linkingValues) {
|
||||
String fieldName = priorLinking.getFieldName();
|
||||
String fieldValue = priorLinking.getFieldValue();
|
||||
priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
|
||||
}
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String questionAugmented = String.format("%s (补充信息:%s) ", chatParseContext.getQueryText(), linkingListStr);
|
||||
log.info("questionAugmented:{}", questionAugmented);
|
||||
return questionAugmented;
|
||||
}
|
||||
|
||||
private boolean checkSkip(ParseResp parseResp) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
for (SemanticParseInfo semanticParseInfo : selectedParses) {
|
||||
|
||||
@@ -64,7 +64,6 @@ public class ChatServiceImpl implements ChatService {
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
String queryText = chatParseReq.getQueryText();
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
@@ -74,8 +73,8 @@ public class ChatServiceImpl implements ChatService {
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(chatParseContext, parseResp);
|
||||
}
|
||||
parseResp.setQueryText(queryText);
|
||||
chatParseReq.setQueryText(queryText);
|
||||
chatParseReq.setQueryText(chatParseContext.getQueryText());
|
||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
return parseResp;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user