From cbafff0935628649086e7ce6176570ca8e842195 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sun, 19 May 2024 16:43:05 +0800 Subject: [PATCH] (improvement)(chat)Implement a new version of multi-turn conversation. --- README.md | 6 +- README_CN.md | 6 +- .../chat/server/parser/MultiTurnParser.java | 146 ++++++++++++++++++ .../chat/server/parser/NL2SQLParser.java | 128 --------------- .../server/service/impl/ChatServiceImpl.java | 5 +- .../parser/llm/RewriteExamplarLoader.java | 32 ---- .../core/chat/parser/llm/RewriteExample.java | 14 -- .../parser/llm/RewriteQueryGeneration.java | 54 ------- .../chat/parser/llm/SqlPromptGenerator.java | 15 -- .../main/resources/META-INF/spring.factories | 1 + 10 files changed, 155 insertions(+), 252 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java diff --git a/README.md b/README.md index cfd86519b..f39673055 100644 --- a/README.md +++ b/README.md @@ -32,9 +32,9 @@ With these ideas in mind, we develop SuperSonic as a practical reference impleme - Built-in Chat BI interface for *business users* to enter natural language queries - Built-in Headless BI interface for *analytics engineers* to build semantic data models -- Built-in rule-based semantic parser to improve efficiency in certain scenarios -- Support input auto-completion as well as query recommendation -- Support four-level permission control: domain-level, model-level, column-level and row-level +- Built-in rule-based semantic parser to improve efficiency in certain scenarios (e.g. demonstration, integration testing) +- Built-in support for input auto-completion, multi-turn conversation as well as post-query recommendation +- Built-in support for three-level data access control: dataset-level, column-level and row-level ## Extensible Components diff --git a/README_CN.md b/README_CN.md index 9c23f0ca3..4e96f036c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -28,9 +28,9 @@ - 内置Chat BI界面以便*业务用户*输入数据查询。 - 内置Headless BI界面以便*分析工程师*构建语义模型。 -- 内置基于规则的语义解析器,在特定场景可以提升运行效率。 -- 支持文本输入的联想和查询问题的推荐。 -- 支持四级权限控制:主题域级、模型级、列级、行级。 +- 内置基于规则的语义解析器,在特定场景(比如DEMO演示、集成测试)可以提升推理效率。 +- 支持文本输入联想、多轮对话、查询后问题推荐等高级特征。 +- 支持三级权限控制:数据集级、列级、行级。 ## 易于扩展的组件 diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java new file mode 100644 index 000000000..257e49f5a --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -0,0 +1,146 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.chat.server.util.QueryReqConverter; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.response.MapResp; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.server.service.ChatQueryService; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.Response; +import lombok.Builder; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.core.env.Environment; + +import java.util.Map; +import java.util.HashMap; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; +import java.util.Collections; + +@Slf4j +public class MultiTurnParser implements ChatParser { + + private static final Logger keyPipelineLog = LoggerFactory.getLogger(MultiTurnParser.class); + + private static final PromptTemplate promptTemplate = PromptTemplate.from( + "You are a data product manager experienced in data requirements." + + "Your will be provided with current and history questions asked by a user," + + "along with their mapped schema elements(metric, dimension and value), " + + "please try understanding the semantics and rewrite a question" + + "(keep relevant metrics, dimensions, values and date ranges)." + + "Current Question: {{curtQuestion}} " + + "Current Mapped Schema: {{curtSchema}} " + + "History Question: {{histQuestion}} " + + "History Mapped Schema: {{histSchema}} " + + "Rewritten Question: "); + + @Override + public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + Environment environment = ContextUtils.getBean(Environment.class); + Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class); + if (Boolean.FALSE.equals(multiTurn)) { + return; + } + + // derive mapping result of current question and parsing result of last question. + ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + MapResp currentMapResult = chatQueryService.performMapping(queryReq); + + List historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1); + if (historyParseResults.size() == 0) { + return; + } + ParseResp lastParseResult = historyParseResults.get(0); + Long dataId = lastParseResult.getSelectedParses().get(0).getDataSetId(); + + String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); + String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); + String rewrittenQuery = rewriteQuery(RewriteContext.builder() + .curtQuestion(currentMapResult.getQueryText()) + .histQuestion(lastParseResult.getQueryText()) + .curtSchema(curtMapStr) + .histSchema(histMapStr) + .build()); + chatParseContext.setQueryText(rewrittenQuery); + log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", + lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); + } + + private String rewriteQuery(RewriteContext context) { + + Map variables = new HashMap<>(); + variables.put("curtQuestion", context.getCurtQuestion()); + variables.put("histQuestion", context.getHistQuestion()); + variables.put("curtSchema", context.getCurtSchema()); + variables.put("histSchema", context.getHistSchema()); + + Prompt prompt = promptTemplate.apply(variables); + keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); + ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); + Response response = chatLanguageModel.generate(prompt.toSystemMessage()); + String result = response.content().text(); + keyPipelineLog.info("model response:{}", result); + //3.format response. + String rewriteQuery = response.content().text(); + + return rewriteQuery; + } + + private String generateSchemaPrompt(List elementMatches) { + List metrics = new ArrayList<>(); + List dimensions = new ArrayList<>(); + List values = new ArrayList<>(); + + for (SchemaElementMatch match : elementMatches) { + if (match.getElement().getType().equals(SchemaElementType.METRIC)) { + metrics.add(match.getWord()); + } else if (match.getElement().getType().equals(SchemaElementType.DIMENSION)) { + dimensions.add(match.getWord()); + } else if (match.getElement().getType().equals(SchemaElementType.VALUE)) { + values.add(match.getWord()); + } + } + + StringBuilder prompt = new StringBuilder(); + prompt.append(String.format("'metrics:':[%s]", String.join(",", metrics))); + prompt.append(","); + prompt.append(String.format("'dimensions:':[%s]", String.join(",", dimensions))); + prompt.append(","); + prompt.append(String.format("'values:':[%s]", String.join(",", values))); + + return prompt.toString(); + } + + private List getHistoryParseResult(int chatId, int multiNum) { + ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); + List contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) + .stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList()); + + List contextualList = contextualParseInfoList.subList(0, + Math.min(multiNum, contextualParseInfoList.size())); + Collections.reverse(contextualList); + return contextualList; + } + + @Data + @Builder + public static class RewriteContext { + private String curtQuestion; + private String histQuestion; + private String curtSchema; + private String histSchema; + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index b884f761a..6ef246c18 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,47 +1,20 @@ package com.tencent.supersonic.chat.server.parser; -import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper; -import com.tencent.supersonic.headless.core.chat.parser.llm.LLMRequestService; -import com.tencent.supersonic.headless.core.chat.parser.llm.RewriteQueryGeneration; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery; -import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.server.service.ChatQueryService; import java.util.List; -import java.util.Collections; -import java.util.ArrayList; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; - - -import com.tencent.supersonic.headless.server.service.impl.ChatQueryServiceImpl; -import com.tencent.supersonic.headless.server.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.MapUtils; -import org.apache.commons.lang3.StringUtils; -import org.springframework.core.env.Environment; - -import static com.tencent.supersonic.common.pojo.Constants.CONTEXT; @Slf4j public class NL2SQLParser implements ChatParser { - private int contextualNum = 5; - - private List schemaMappers = ComponentFactory.getSchemaMappers(); - @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { if (!chatParseContext.enableNL2SQL()) { @@ -50,15 +23,8 @@ public class NL2SQLParser implements ChatParser { if (checkSkip(parseResp)) { return; } - considerMultiturn(chatParseContext, parseResp); QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); - Environment environment = ContextUtils.getBean(Environment.class); - String multiTurn = environment.getProperty("multi.turn"); - if (StringUtils.isNotBlank(multiTurn) && Boolean.parseBoolean(multiTurn)) { - queryReq.setMapInfo(new SchemaMapInfo()); - } - ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { @@ -67,100 +33,6 @@ public class NL2SQLParser implements ChatParser { parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); } - private void considerMultiturn(ChatParseContext chatParseContext, ParseResp parseResp) { - Environment environment = ContextUtils.getBean(Environment.class); - RewriteQueryGeneration rewriteQueryGeneration = ContextUtils.getBean(RewriteQueryGeneration.class); - String multiTurn = environment.getProperty("multi.turn"); - String multiNum = environment.getProperty("multi.num"); - if (StringUtils.isBlank(multiTurn) || !Boolean.parseBoolean(multiTurn)) { - return; - } - log.info("multi turn text-to-sql!"); - List contextualList = getContextualList(parseResp, multiNum); - List contextualQuestions = getContextualQuestionsWithLink(contextualList); - StringBuffer currentPromptSb = new StringBuffer(); - if (contextualQuestions.size() == 0) { - currentPromptSb.append("contextualQuestions:" + "\n"); - } else { - currentPromptSb.append("contextualQuestions:" + "\n" + String.join("\n", contextualQuestions) + "\n"); - } - String currentQuestion = getQueryLinks(chatParseContext); - currentPromptSb.append("currentQuestion:" + currentQuestion + "\n"); - currentPromptSb.append("rewritingCurrentQuestion:\n"); - String rewriteQuery = rewriteQueryGeneration.generation(currentPromptSb.toString()); - log.info("rewriteQuery:{}", rewriteQuery); - chatParseContext.setQueryText(rewriteQuery); - } - - private List getContextualQuestionsWithLink(List contextualList) { - List contextualQuestions = new ArrayList<>(); - contextualList.stream().forEach(o -> { - Map map = JsonUtil.toMap(JsonUtil.toString( - o.getSelectedParses().get(0).getProperties().get(CONTEXT)), String.class, Object.class); - LLMReq llmReq = JsonUtil.toObject(JsonUtil.toString(map.get("llmReq")), LLMReq.class); - List linking = llmReq.getLinking(); - List priorLinkingList = new ArrayList<>(); - for (LLMReq.ElementValue priorLinking : linking) { - String fieldName = priorLinking.getFieldName(); - String fieldValue = priorLinking.getFieldValue(); - priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘"); - } - String linkingListStr = String.join(",", priorLinkingList); - String questionAugmented = String.format("%s (补充信息:%s) ", o.getQueryText(), linkingListStr); - contextualQuestions.add(questionAugmented); - }); - return contextualQuestions; - } - - private List getContextualList(ParseResp parseResp, String multiNum) { - ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); - List contextualParseInfoList = chatQueryRepository.getContextualParseInfo( - parseResp.getChatId()).stream().filter(o -> o.getSelectedParses().get(0) - .getQueryMode().equals(LLMSqlQuery.QUERY_MODE) - ).collect(Collectors.toList()); - if (StringUtils.isNotBlank(multiNum) && StringUtils.isNumeric(multiNum)) { - int num = Integer.parseInt(multiNum); - contextualNum = Math.min(contextualNum, num); - } - List contextualList = contextualParseInfoList.subList(0, - Math.min(contextualNum, contextualParseInfoList.size())); - Collections.reverse(contextualList); - return contextualList; - } - - private String getQueryLinks(ChatParseContext chatParseContext) { - QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); - - ChatQueryServiceImpl chatQueryService = ContextUtils.getBean(ChatQueryServiceImpl.class); - // build queryContext and chatContext - QueryContext queryCtx = chatQueryService.buildQueryContext(queryReq); - - // 1. mapper - if (Objects.isNull(chatParseContext.getMapInfo()) - || MapUtils.isEmpty(chatParseContext.getMapInfo().getDataSetElementMatches())) { - schemaMappers.forEach(mapper -> { - mapper.map(queryCtx); - }); - } - LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); - Long dataSetId = requestService.getDataSetId(queryCtx); - log.info("dataSetId:{}", dataSetId); - if (dataSetId == null) { - return null; - } - List linkingValues = requestService.getValueList(queryCtx, dataSetId); - List priorLinkingList = new ArrayList<>(); - for (LLMReq.ElementValue priorLinking : linkingValues) { - String fieldName = priorLinking.getFieldName(); - String fieldValue = priorLinking.getFieldValue(); - priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘"); - } - String linkingListStr = String.join(",", priorLinkingList); - String questionAugmented = String.format("%s (补充信息:%s) ", chatParseContext.getQueryText(), linkingListStr); - log.info("questionAugmented:{}", questionAugmented); - return questionAugmented; - } - private boolean checkSkip(ParseResp parseResp) { List selectedParses = parseResp.getSelectedParses(); for (SemanticParseInfo semanticParseInfo : selectedParses) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index 58607fa1c..e3d07b949 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -64,7 +64,6 @@ public class ChatServiceImpl implements ChatService { @Override public ParseResp performParsing(ChatParseReq chatParseReq) { - String queryText = chatParseReq.getQueryText(); ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText()); chatManageService.createChatQuery(chatParseReq, parseResp); ChatParseContext chatParseContext = buildParseContext(chatParseReq); @@ -74,8 +73,8 @@ public class ChatServiceImpl implements ChatService { for (ParseResultProcessor processor : parseResultProcessors) { processor.process(chatParseContext, parseResp); } - parseResp.setQueryText(queryText); - chatParseReq.setQueryText(queryText); + chatParseReq.setQueryText(chatParseContext.getQueryText()); + parseResp.setQueryText(chatParseContext.getQueryText()); chatManageService.batchAddParse(chatParseReq, parseResp); return parseResp; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java deleted file mode 100644 index 9e8b36bcb..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java +++ /dev/null @@ -1,32 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - - -import com.fasterxml.jackson.core.type.TypeReference; -import com.tencent.supersonic.common.util.JsonUtil; -import lombok.extern.slf4j.Slf4j; -import org.springframework.core.io.ClassPathResource; -import org.springframework.stereotype.Component; - -import java.io.InputStream; -import java.util.List; -import java.util.ArrayList; - -@Slf4j -@Component -public class RewriteExamplarLoader { - - private static final String EXAMPLE_JSON_FILE = "rewrite_examplar.json"; - - private TypeReference> valueTypeRef = new TypeReference>() { - }; - - public List getRewriteExamples() { - try { - ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); - InputStream inputStream = resource.getInputStream(); - return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); - } catch (Exception e) { - return new ArrayList<>(); - } - } -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java deleted file mode 100644 index d21467dca..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - -import lombok.Data; - -@Data -public class RewriteExample { - - private String contextualQuestions; - - private String currentQuestion; - - private String rewritingCurrentQuestion; - -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java deleted file mode 100644 index 0abd6cdc0..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java +++ /dev/null @@ -1,54 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - -import com.tencent.supersonic.common.util.JsonUtil; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.input.Prompt; -import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.output.Response; -import lombok.extern.slf4j.Slf4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -@Service -@Slf4j -public class RewriteQueryGeneration { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Autowired - private ChatLanguageModel chatLanguageModel; - - @Autowired - private RewriteExamplarLoader rewriteExamplarLoader; - - @Autowired - private SqlPromptGenerator sqlPromptGenerator; - - public String generation(String currentPromptStr) { - //1.retriever sqlExamples - - List> rewriteExamples = rewriteExamplarLoader.getRewriteExamples().stream().map(o -> { - return JsonUtil.toMap(JsonUtil.toString(o), String.class, String.class); - }).collect(Collectors.toList()); - - //2.generator linking and sql prompt by sqlExamples,and generate response. - String promptStr = sqlPromptGenerator.generateRewritePrompt(rewriteExamples) + currentPromptStr; - - Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>()); - keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); - Response response = chatLanguageModel.generate(prompt.toSystemMessage()); - String result = response.content().text(); - keyPipelineLog.info("model response:{}", result); - //3.format response. - String rewriteQuery = response.content().text(); - - return rewriteQuery; - } -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java index 651a07893..889dabe46 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java @@ -129,19 +129,4 @@ public class SqlPromptGenerator { return sqlPromptPool; } - public String generateRewritePrompt(List> rewriteExamples) { - String instruction = "#this is a multi-turn text-to-sql scenes,you need consider the contextual " - + "questions and semantics, rewriting current question for expressing complete semantics of " - + "the current question based on the contextual questions."; - List exampleKeys = Arrays.asList("contextualQuestions", "currentQuestion", "rewritingCurrentQuestion"); - StringBuilder rewriteSb = new StringBuilder(); - rewriteExamples.stream().forEach(o -> { - exampleKeys.stream().forEach(example -> { - rewriteSb.append(example + ":" + o.get(example) + "\n"); - }); - rewriteSb.append("\n"); - }); - return instruction + InputFormat.SEPERATOR + rewriteSb.toString(); - } - } diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 46142a5a8..e0750b1e4 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -47,6 +47,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.chat.server.parser.ChatParser=\ com.tencent.supersonic.chat.server.parser.NL2PluginParser, \ + com.tencent.supersonic.chat.server.parser.MultiTurnParser,\ com.tencent.supersonic.chat.server.parser.NL2SQLParser com.tencent.supersonic.chat.server.executor.ChatExecutor=\