(improvement)(chat)Implement a new version of multi-turn conversation.

This commit is contained in:
jerryjzhang
2024-05-19 16:43:05 +08:00
parent 710f120e38
commit cbafff0935
10 changed files with 155 additions and 252 deletions

View File

@@ -32,9 +32,9 @@ With these ideas in mind, we develop SuperSonic as a practical reference impleme
- Built-in Chat BI interface for *business users* to enter natural language queries
- Built-in Headless BI interface for *analytics engineers* to build semantic data models
- Built-in rule-based semantic parser to improve efficiency in certain scenarios
- Support input auto-completion as well as query recommendation
- Support four-level permission control: domain-level, model-level, column-level and row-level
- Built-in rule-based semantic parser to improve efficiency in certain scenarios (e.g. demonstration, integration testing)
- Built-in support for input auto-completion, multi-turn conversation as well as post-query recommendation
- Built-in support for three-level data access control: dataset-level, column-level and row-level
## Extensible Components

View File

@@ -28,9 +28,9 @@
- 内置Chat BI界面以便*业务用户*输入数据查询。
- 内置Headless BI界面以便*分析工程师*构建语义模型。
- 内置基于规则的语义解析器,在特定场景可以提升运行效率。
- 支持文本输入联想和查询问题推荐。
- 支持四级权限控制:主题域级、模型级、列级、行级。
- 内置基于规则的语义解析器,在特定场景比如DEMO演示、集成测试可以提升推理效率。
- 支持文本输入联想、多轮对话、查询问题推荐等高级特征。
- 支持三级权限控制:数据集级、列级、行级。
## 易于扩展的组件

View File

@@ -0,0 +1,146 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import java.util.Map;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.Collections;
/**
 * Chat parser that rewrites the current question using the previous turn of the conversation.
 *
 * <p>The parser collects the schema elements mapped for the current question and the element
 * matches recorded for the last non-failed parse of the same chat, asks the configured
 * {@link ChatLanguageModel} to rewrite the question, and stores the rewritten text back into
 * the {@link ChatParseContext} so that parsers running afterwards see the rewritten question.
 */
@Slf4j
public class MultiTurnParser implements ChatParser {

    // Dedicated named logger for prompt/response tracing; the class-scoped logger is already
    // available as `log` through @Slf4j.
    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    // Each fragment ends with a space so that sentences do not run together once concatenated.
    private static final PromptTemplate promptTemplate = PromptTemplate.from(
            "You are a data product manager experienced in data requirements. "
                    + "You will be provided with current and history questions asked by a user, "
                    + "along with their mapped schema elements(metric, dimension and value), "
                    + "please try understanding the semantics and rewrite a question "
                    + "(keep relevant metrics, dimensions, values and date ranges). "
                    + "Current Question: {{curtQuestion}} "
                    + "Current Mapped Schema: {{curtSchema}} "
                    + "History Question: {{histQuestion}} "
                    + "History Mapped Schema: {{histSchema}} "
                    + "Rewritten Question: ");

    /**
     * Rewrites the query text held by {@code chatParseContext} based on the previous turn.
     *
     * <p>Nothing is changed when the {@code multi.turn} property is explicitly {@code false},
     * when the chat has no usable history, or when the model returns a blank answer.
     *
     * @param chatParseContext context of the current question; its query text may be replaced
     * @param parseResp        parse response of the current request (not read by this parser)
     */
    @Override
    public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
        Environment environment = ContextUtils.getBean(Environment.class);
        Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class);
        // NOTE(review): only an explicit "false" disables rewriting; an unset property leaves it
        // enabled. Confirm that this default is intended.
        if (Boolean.FALSE.equals(multiTurn)) {
            return;
        }

        // Look up history first: without a previous turn there is nothing to rewrite, so the
        // mapping call below can be skipped entirely.
        List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
        if (historyParseResults.isEmpty()) {
            return;
        }
        ParseResp lastParseResult = historyParseResults.get(0);
        // A non-failed history entry may still carry no selected parse; get(0) would throw.
        if (lastParseResult.getSelectedParses() == null || lastParseResult.getSelectedParses().isEmpty()) {
            return;
        }

        // Derive the mapping result of the current question against the data set of the last turn.
        ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
        QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
        MapResp currentMapResult = chatQueryService.performMapping(queryReq);

        Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId();
        String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
        String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());

        String rewrittenQuery = rewriteQuery(RewriteContext.builder()
                .curtQuestion(currentMapResult.getQueryText())
                .histQuestion(lastParseResult.getQueryText())
                .curtSchema(curtMapStr)
                .histSchema(histMapStr)
                .build());
        // Keep the user's original question rather than replacing it with an empty answer.
        if (rewrittenQuery == null || rewrittenQuery.trim().isEmpty()) {
            log.warn("Empty rewritten query returned by the model, keep original query: {}",
                    currentMapResult.getQueryText());
            return;
        }
        chatParseContext.setQueryText(rewrittenQuery);
        log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
                lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
    }

    /**
     * Fills the prompt template from {@code context}, sends it to the chat model as a system
     * message and returns the text of the model's answer.
     */
    private String rewriteQuery(RewriteContext context) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("curtQuestion", context.getCurtQuestion());
        variables.put("histQuestion", context.getHistQuestion());
        variables.put("curtSchema", context.getCurtSchema());
        variables.put("histSchema", context.getHistSchema());

        Prompt prompt = promptTemplate.apply(variables);
        keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());

        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
        Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
        String result = response.content().text();
        keyPipelineLog.info("model response:{}", result);
        return result;
    }

    /**
     * Renders the matched words grouped by element type (metric, dimension, value) as a single
     * line of text for the prompt. A {@code null} list is rendered like an empty one.
     */
    private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
        List<String> metrics = new ArrayList<>();
        List<String> dimensions = new ArrayList<>();
        List<String> values = new ArrayList<>();
        if (elementMatches != null) {
            for (SchemaElementMatch match : elementMatches) {
                SchemaElementType type = match.getElement().getType();
                if (SchemaElementType.METRIC.equals(type)) {
                    metrics.add(match.getWord());
                } else if (SchemaElementType.DIMENSION.equals(type)) {
                    dimensions.add(match.getWord());
                } else if (SchemaElementType.VALUE.equals(type)) {
                    values.add(match.getWord());
                }
            }
        }

        StringBuilder prompt = new StringBuilder();
        prompt.append(String.format("'metrics:':[%s]", String.join(",", metrics)));
        prompt.append(",");
        prompt.append(String.format("'dimensions:':[%s]", String.join(",", dimensions)));
        prompt.append(",");
        prompt.append(String.format("'values:':[%s]", String.join(",", values)));
        return prompt.toString();
    }

    /**
     * Returns at most {@code multiNum} non-failed parse results of the chat: the leading entries
     * of what the repository returns, in reversed order.
     */
    private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
        ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
        List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
                .stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
        List<ParseResp> contextualList = contextualParseInfoList.subList(0,
                Math.min(multiNum, contextualParseInfoList.size()));
        Collections.reverse(contextualList);
        return contextualList;
    }

    /** Values substituted into the rewrite prompt template. */
    @Data
    @Builder
    public static class RewriteContext {
        // Text of the question being asked now.
        private String curtQuestion;
        // Text of the previous question in the same chat.
        private String histQuestion;
        // Rendered schema elements mapped for the current question.
        private String curtSchema;
        // Rendered schema elements matched for the previous question.
        private String histSchema;
    }
}

View File

@@ -1,47 +1,20 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.core.chat.parser.llm.LLMRequestService;
import com.tencent.supersonic.headless.core.chat.parser.llm.RewriteQueryGeneration;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import java.util.List;
import java.util.Collections;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import com.tencent.supersonic.headless.server.service.impl.ChatQueryServiceImpl;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import static com.tencent.supersonic.common.pojo.Constants.CONTEXT;
@Slf4j
public class NL2SQLParser implements ChatParser {
private int contextualNum = 5;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL()) {
@@ -50,15 +23,8 @@ public class NL2SQLParser implements ChatParser {
if (checkSkip(parseResp)) {
return;
}
considerMultiturn(chatParseContext, parseResp);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
Environment environment = ContextUtils.getBean(Environment.class);
String multiTurn = environment.getProperty("multi.turn");
if (StringUtils.isNotBlank(multiTurn) && Boolean.parseBoolean(multiTurn)) {
queryReq.setMapInfo(new SchemaMapInfo());
}
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
@@ -67,100 +33,6 @@ public class NL2SQLParser implements ChatParser {
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
}
/**
 * When the {@code multi.turn} property parses to {@code true}, builds a prompt from the
 * contextual questions and the current question, asks {@link RewriteQueryGeneration} to
 * rewrite it, and stores the rewritten text as the query text of {@code chatParseContext}.
 */
private void considerMultiturn(ChatParseContext chatParseContext, ParseResp parseResp) {
    Environment env = ContextUtils.getBean(Environment.class);
    RewriteQueryGeneration rewriter = ContextUtils.getBean(RewriteQueryGeneration.class);
    String multiTurnFlag = env.getProperty("multi.turn");
    String multiNumValue = env.getProperty("multi.num");
    // Boolean.parseBoolean yields false for null and blank input, so no separate blank check is needed.
    if (!Boolean.parseBoolean(multiTurnFlag)) {
        return;
    }
    log.info("multi turn text-to-sql!");

    List<ParseResp> history = getContextualList(parseResp, multiNumValue);
    List<String> historyQuestions = getContextualQuestionsWithLink(history);

    StringBuilder promptBuilder = new StringBuilder();
    promptBuilder.append("contextualQuestions:").append("\n");
    if (!historyQuestions.isEmpty()) {
        promptBuilder.append(String.join("\n", historyQuestions)).append("\n");
    }

    String augmentedCurrent = getQueryLinks(chatParseContext);
    promptBuilder.append("currentQuestion:").append(augmentedCurrent).append("\n");
    promptBuilder.append("rewritingCurrentQuestion:\n");

    String rewritten = rewriter.generation(promptBuilder.toString());
    log.info("rewriteQuery:{}", rewritten);
    chatParseContext.setQueryText(rewritten);
}
/**
 * Turns each history parse result into its query text followed by a description of the
 * field values recorded in the {@code llmReq} entry of the first selected parse's context.
 */
private List<String> getContextualQuestionsWithLink(List<ParseResp> contextualList) {
    List<String> augmentedQuestions = new ArrayList<>();
    for (ParseResp history : contextualList) {
        // The context property is round-tripped through JSON to obtain a plain map view.
        Map<String, Object> contextMap = JsonUtil.toMap(JsonUtil.toString(
                history.getSelectedParses().get(0).getProperties().get(CONTEXT)), String.class, Object.class);
        LLMReq llmReq = JsonUtil.toObject(JsonUtil.toString(contextMap.get("llmReq")), LLMReq.class);

        List<String> linkDescriptions = new ArrayList<>();
        for (LLMReq.ElementValue link : llmReq.getLinking()) {
            linkDescriptions.add("" + link.getFieldValue() + "‘是一个‘" + link.getFieldName() + "");
        }

        augmentedQuestions.add(String.format("%s (补充信息:%s) ",
                history.getQueryText(), String.join("", linkDescriptions)));
    }
    return augmentedQuestions;
}
/**
 * Returns the leading LLM-SQL parse results of the chat, reversed in order, limited to the
 * smaller of the {@code contextualNum} field and {@code multiNum}.
 *
 * @param parseResp parse response whose chat id selects the history
 * @param multiNum  optional numeric limit as text; ignored when blank or not numeric
 * @return the selected history entries, possibly empty
 */
private List<ParseResp> getContextualList(ParseResp parseResp, String multiNum) {
    ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
    List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(
            parseResp.getChatId()).stream()
            // Entries without a selected parse cannot be inspected and would make get(0) throw.
            .filter(o -> o.getSelectedParses() != null && !o.getSelectedParses().isEmpty())
            // Constant-first comparison tolerates a missing query mode.
            .filter(o -> LLMSqlQuery.QUERY_MODE.equals(o.getSelectedParses().get(0).getQueryMode()))
            .collect(Collectors.toList());

    // Use a local limit: writing the result back to the contextualNum field would make a small
    // multi.num value shrink the limit permanently for every later call on this instance.
    int limit = contextualNum;
    if (StringUtils.isNotBlank(multiNum) && StringUtils.isNumeric(multiNum)) {
        limit = Math.min(limit, Integer.parseInt(multiNum));
    }

    List<ParseResp> contextualList = contextualParseInfoList.subList(0,
            Math.min(limit, contextualParseInfoList.size()));
    Collections.reverse(contextualList);
    return contextualList;
}
/**
 * Builds the current question augmented with the field values linked to it.
 *
 * <p>Returns {@code null} when no data set id can be resolved for the query context.
 */
private String getQueryLinks(ChatParseContext chatParseContext) {
    QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
    ChatQueryServiceImpl chatQueryService = ContextUtils.getBean(ChatQueryServiceImpl.class);
    QueryContext queryCtx = chatQueryService.buildQueryContext(queryReq);

    // Run the schema mappers only when the parse context carries no element matches yet.
    boolean alreadyMapped = !Objects.isNull(chatParseContext.getMapInfo())
            && !MapUtils.isEmpty(chatParseContext.getMapInfo().getDataSetElementMatches());
    if (!alreadyMapped) {
        for (SchemaMapper mapper : schemaMappers) {
            mapper.map(queryCtx);
        }
    }

    LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
    Long dataSetId = requestService.getDataSetId(queryCtx);
    log.info("dataSetId:{}", dataSetId);
    if (dataSetId == null) {
        return null;
    }

    List<String> linkDescriptions = new ArrayList<>();
    for (LLMReq.ElementValue link : requestService.getValueList(queryCtx, dataSetId)) {
        linkDescriptions.add("" + link.getFieldValue() + "‘是一个‘" + link.getFieldName() + "");
    }

    String questionAugmented = String.format("%s (补充信息:%s) ",
            chatParseContext.getQueryText(), String.join("", linkDescriptions));
    log.info("questionAugmented:{}", questionAugmented);
    return questionAugmented;
}
private boolean checkSkip(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo semanticParseInfo : selectedParses) {

View File

@@ -64,7 +64,6 @@ public class ChatServiceImpl implements ChatService {
@Override
public ParseResp performParsing(ChatParseReq chatParseReq) {
String queryText = chatParseReq.getQueryText();
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp);
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
@@ -74,8 +73,8 @@ public class ChatServiceImpl implements ChatService {
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(chatParseContext, parseResp);
}
parseResp.setQueryText(queryText);
chatParseReq.setQueryText(queryText);
chatParseReq.setQueryText(chatParseContext.getQueryText());
parseResp.setQueryText(chatParseContext.getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp);
return parseResp;
}

View File

@@ -1,32 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
import java.io.InputStream;
import java.util.List;
import java.util.ArrayList;
/**
 * Loads the rewrite examples from the {@code rewrite_examplar.json} classpath resource.
 */
@Slf4j
@Component
public class RewriteExamplarLoader {

    private static final String EXAMPLE_JSON_FILE = "rewrite_examplar.json";

    private final TypeReference<List<RewriteExample>> valueTypeRef = new TypeReference<List<RewriteExample>>() {
    };

    /**
     * Reads and deserializes the example file.
     *
     * @return the parsed examples, or an empty mutable list when the resource cannot be read
     *         or parsed (loading stays best-effort, but the failure is logged)
     */
    public List<RewriteExample> getRewriteExamples() {
        ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
        // try-with-resources closes the stream on every path, including parse failures.
        try (InputStream inputStream = resource.getInputStream()) {
            return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
        } catch (Exception e) {
            log.warn("Failed to load rewrite examples from classpath resource {}", EXAMPLE_JSON_FILE, e);
            return new ArrayList<>();
        }
    }
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import lombok.Data;
/**
 * One rewrite example. The three field names are the same keys that
 * SqlPromptGenerator#generateRewritePrompt reads when rendering examples into the prompt.
 */
@Data
public class RewriteExample {
// Questions asked earlier in the conversation.
private String contextualQuestions;
// The question being asked now.
private String currentQuestion;
// The current question rewritten with the conversation context taken into account.
private String rewritingCurrentQuestion;
}

View File

@@ -1,54 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * Service that asks the chat model to rewrite a question, prefixing the caller's prompt with
 * the instruction and examples produced by {@link SqlPromptGenerator}.
 */
@Service
@Slf4j
public class RewriteQueryGeneration {

    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    @Autowired
    private ChatLanguageModel chatLanguageModel;

    @Autowired
    private RewriteExamplarLoader rewriteExamplarLoader;

    @Autowired
    private SqlPromptGenerator sqlPromptGenerator;

    /**
     * Sends the example prompt followed by {@code currentPromptStr} to the chat model as a
     * system message.
     *
     * @param currentPromptStr prompt section describing the contextual and current questions
     * @return the text of the model's answer
     */
    public String generation(String currentPromptStr) {
        // Convert every loaded example into a flat string map via a JSON round trip.
        List<Map<String, String>> exampleMaps = rewriteExamplarLoader.getRewriteExamples().stream()
                .map(example -> JsonUtil.toMap(JsonUtil.toString(example), String.class, String.class))
                .collect(Collectors.toList());

        // Instruction and examples first, then the caller's section.
        String fullPrompt = sqlPromptGenerator.generateRewritePrompt(exampleMaps) + currentPromptStr;
        Prompt prompt = PromptTemplate.from(JsonUtil.toString(fullPrompt)).apply(new HashMap<>());
        keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());

        Response<AiMessage> modelResponse = chatLanguageModel.generate(prompt.toSystemMessage());
        String answer = modelResponse.content().text();
        keyPipelineLog.info("model response:{}", answer);
        return answer;
    }
}

View File

@@ -129,19 +129,4 @@ public class SqlPromptGenerator {
return sqlPromptPool;
}
/**
 * Renders the rewrite instruction followed by every example, one {@code key:value} line per
 * example key and a blank line after each example.
 *
 * @param rewriteExamples examples keyed by contextualQuestions, currentQuestion and
 *                        rewritingCurrentQuestion
 * @return instruction, separator and rendered examples concatenated
 */
public String generateRewritePrompt(List<Map<String, String>> rewriteExamples) {
    String instruction = "#this is a multi-turn text-to-sql scenes,you need consider the contextual "
            + "questions and semantics, rewriting current question for expressing complete semantics of "
            + "the current question based on the contextual questions.";
    List<String> exampleKeys = Arrays.asList("contextualQuestions", "currentQuestion", "rewritingCurrentQuestion");

    StringBuilder rendered = new StringBuilder();
    for (Map<String, String> example : rewriteExamples) {
        for (String key : exampleKeys) {
            // A missing key is rendered as the text "null", exactly as string concatenation would.
            rendered.append(key).append(":").append(example.get(key)).append("\n");
        }
        rendered.append("\n");
    }
    return instruction + InputFormat.SEPERATOR + rendered;
}
}

View File

@@ -47,6 +47,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.server.parser.ChatParser=\
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
com.tencent.supersonic.chat.server.parser.MultiTurnParser,\
com.tencent.supersonic.chat.server.parser.NL2SQLParser
com.tencent.supersonic.chat.server.executor.ChatExecutor=\