From 0e28d6cbccfe055b270ff560a26d800d27556a8d Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Mon, 13 May 2024 14:13:02 +0800 Subject: [PATCH] (improvement)(Headless) support multiturn text-to-sql (#983) --- .../chat/server/parser/NL2SQLParser.java | 131 ++++++++++++++++++ .../persistence/mapper/ChatParseMapper.java | 2 + .../repository/ChatQueryRepository.java | 2 + .../impl/ChatQueryRepositoryImpl.java | 22 ++- .../chat/server/rest/ChatController.java | 8 +- .../server/service/impl/ChatServiceImpl.java | 3 + .../main/resources/mapper/ChatParseMapper.xml | 6 + .../chat/parser/llm/LLMRequestService.java | 2 +- .../parser/llm/RewriteExamplarLoader.java | 32 +++++ .../core/chat/parser/llm/RewriteExample.java | 14 ++ .../parser/llm/RewriteQueryGeneration.java | 54 ++++++++ .../chat/parser/llm/SqlPromptGenerator.java | 15 ++ .../service/impl/ChatQueryServiceImpl.java | 2 +- .../src/main/resources/application-local.yaml | 3 + .../src/main/resources/rewrite_examplar.json | 122 ++++++++++++++++ 15 files changed, 407 insertions(+), 11 deletions(-) create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java create mode 100644 launchers/standalone/src/main/resources/rewrite_examplar.json diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 5da72ec1f..b884f761a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,18 +1,47 @@ package com.tencent.supersonic.chat.server.parser; +import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper; +import com.tencent.supersonic.headless.core.chat.parser.llm.LLMRequestService; +import com.tencent.supersonic.headless.core.chat.parser.llm.RewriteQueryGeneration; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery; +import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.server.service.ChatQueryService; + import java.util.List; +import java.util.Collections; +import java.util.ArrayList; +import java.util.Map; +import java.util.Objects; + +import java.util.stream.Collectors; + + +import com.tencent.supersonic.headless.server.service.impl.ChatQueryServiceImpl; +import com.tencent.supersonic.headless.server.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.MapUtils; +import org.apache.commons.lang3.StringUtils; +import org.springframework.core.env.Environment; + +import static com.tencent.supersonic.common.pojo.Constants.CONTEXT; @Slf4j public class NL2SQLParser implements ChatParser { + private int contextualNum = 5; + + private List schemaMappers = ComponentFactory.getSchemaMappers(); + @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { if (!chatParseContext.enableNL2SQL()) { @@ -21,7 +50,15 @@ public class NL2SQLParser implements ChatParser { if (checkSkip(parseResp)) { return; } + considerMultiturn(chatParseContext, parseResp); QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + + Environment environment = ContextUtils.getBean(Environment.class); + String multiTurn = environment.getProperty("multi.turn"); + if (StringUtils.isNotBlank(multiTurn) && Boolean.parseBoolean(multiTurn)) { + queryReq.setMapInfo(new SchemaMapInfo()); + } + ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { @@ -30,6 +67,100 @@ public class NL2SQLParser implements ChatParser { parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); } + private void considerMultiturn(ChatParseContext chatParseContext, ParseResp parseResp) { + Environment environment = ContextUtils.getBean(Environment.class); + RewriteQueryGeneration rewriteQueryGeneration = ContextUtils.getBean(RewriteQueryGeneration.class); + String multiTurn = environment.getProperty("multi.turn"); + String multiNum = environment.getProperty("multi.num"); + if (StringUtils.isBlank(multiTurn) || !Boolean.parseBoolean(multiTurn)) { + return; + } + log.info("multi turn text-to-sql!"); + List contextualList = getContextualList(parseResp, multiNum); + List contextualQuestions = getContextualQuestionsWithLink(contextualList); + StringBuffer currentPromptSb = new StringBuffer(); + if (contextualQuestions.size() == 0) { + currentPromptSb.append("contextualQuestions:" + "\n"); + } else { + currentPromptSb.append("contextualQuestions:" + "\n" + String.join("\n", contextualQuestions) + "\n"); + } + String currentQuestion = getQueryLinks(chatParseContext); + currentPromptSb.append("currentQuestion:" + currentQuestion + "\n"); + currentPromptSb.append("rewritingCurrentQuestion:\n"); + String rewriteQuery = rewriteQueryGeneration.generation(currentPromptSb.toString()); + log.info("rewriteQuery:{}", rewriteQuery); + chatParseContext.setQueryText(rewriteQuery); + } + + private List getContextualQuestionsWithLink(List contextualList) { + List contextualQuestions = new ArrayList<>(); + contextualList.stream().forEach(o -> { + Map map = JsonUtil.toMap(JsonUtil.toString( + o.getSelectedParses().get(0).getProperties().get(CONTEXT)), String.class, Object.class); + LLMReq llmReq = JsonUtil.toObject(JsonUtil.toString(map.get("llmReq")), LLMReq.class); + List linking = llmReq.getLinking(); + List priorLinkingList = new ArrayList<>(); + for (LLMReq.ElementValue priorLinking : linking) { + String fieldName = priorLinking.getFieldName(); + String fieldValue = priorLinking.getFieldValue(); + priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘"); + } + String linkingListStr = String.join(",", priorLinkingList); + String questionAugmented = String.format("%s (补充信息:%s) ", o.getQueryText(), linkingListStr); + contextualQuestions.add(questionAugmented); + }); + return contextualQuestions; + } + + private List getContextualList(ParseResp parseResp, String multiNum) { + ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); + List contextualParseInfoList = chatQueryRepository.getContextualParseInfo( + parseResp.getChatId()).stream().filter(o -> o.getSelectedParses().get(0) + .getQueryMode().equals(LLMSqlQuery.QUERY_MODE) + ).collect(Collectors.toList()); + if (StringUtils.isNotBlank(multiNum) && StringUtils.isNumeric(multiNum)) { + int num = Integer.parseInt(multiNum); + contextualNum = Math.min(contextualNum, num); + } + List contextualList = contextualParseInfoList.subList(0, + Math.min(contextualNum, contextualParseInfoList.size())); + Collections.reverse(contextualList); + return contextualList; + } + + private String getQueryLinks(ChatParseContext chatParseContext) { + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + + ChatQueryServiceImpl chatQueryService = ContextUtils.getBean(ChatQueryServiceImpl.class); + // build queryContext and chatContext + QueryContext queryCtx = chatQueryService.buildQueryContext(queryReq); + + // 1. mapper + if (Objects.isNull(chatParseContext.getMapInfo()) + || MapUtils.isEmpty(chatParseContext.getMapInfo().getDataSetElementMatches())) { + schemaMappers.forEach(mapper -> { + mapper.map(queryCtx); + }); + } + LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); + Long dataSetId = requestService.getDataSetId(queryCtx); + log.info("dataSetId:{}", dataSetId); + if (dataSetId == null) { + return null; + } + List linkingValues = requestService.getValueList(queryCtx, dataSetId); + List priorLinkingList = new ArrayList<>(); + for (LLMReq.ElementValue priorLinking : linkingValues) { + String fieldName = priorLinking.getFieldName(); + String fieldValue = priorLinking.getFieldValue(); + priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘"); + } + String linkingListStr = String.join(",", priorLinkingList); + String questionAugmented = String.format("%s (补充信息:%s) ", chatParseContext.getQueryText(), linkingListStr); + log.info("questionAugmented:{}", questionAugmented); + return questionAugmented; + } + private boolean checkSkip(ParseResp parseResp) { List selectedParses = parseResp.getSelectedParses(); for (SemanticParseInfo semanticParseInfo : selectedParses) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatParseMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatParseMapper.java index d76ea4445..dd2c3648b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatParseMapper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatParseMapper.java @@ -18,4 +18,6 @@ public interface ChatParseMapper { List getParseInfoList(List questionIds); + List getContextualParseInfo(Integer chatId); + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java index 7255ba78e..ae32afb95 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java @@ -36,4 +36,6 @@ public interface ChatQueryRepository { Boolean deleteChatQuery(Long questionId); + List getContextualParseInfo(Integer chatId); + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java index 0c9e07a4b..e09687fa8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -26,6 +26,7 @@ import org.springframework.beans.BeanUtils; import org.springframework.context.annotation.Primary; import org.springframework.stereotype.Repository; import org.springframework.util.CollectionUtils; + import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -44,8 +45,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { private final ShowCaseCustomMapper showCaseCustomMapper; public ChatQueryRepositoryImpl(ChatQueryDOMapper chatQueryDOMapper, - ChatParseMapper chatParseMapper, - ShowCaseCustomMapper showCaseCustomMapper) { + ChatParseMapper chatParseMapper, + ShowCaseCustomMapper showCaseCustomMapper) { this.chatQueryDOMapper = chatQueryDOMapper; this.chatParseMapper = chatParseMapper; this.showCaseCustomMapper = showCaseCustomMapper; @@ -131,7 +132,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { @Override public List batchSaveParseInfo(ChatParseReq chatParseReq, - ParseResp parseResult, List candidateParses) { + ParseResp parseResult, List candidateParses) { List chatParseDOList = new ArrayList<>(); getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList); if (!CollectionUtils.isEmpty(candidateParses)) { @@ -141,7 +142,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { } public void getChatParseDO(ChatParseReq chatParseReq, Long queryId, - List parses, List chatParseDOList) { + List parses, List chatParseDOList) { for (int i = 0; i < parses.size(); i++) { ChatParseDO chatParseDO = new ChatParseDO(); chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId())); @@ -193,4 +194,17 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { return chatQueryDOMapper.deleteByPrimaryKey(questionId); } + @Override + public List getContextualParseInfo(Integer chatId) { + List chatParseDOList = chatParseMapper.getContextualParseInfo(chatId); + List semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> { + ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText()); + List selectedParses = new ArrayList<>(); + selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class)); + parseResp.setSelectedParses(selectedParses); + return parseResp; + }).collect(Collectors.toList()); + return semanticParseInfoList; + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java index b8001fb86..77c8cdf37 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.server.service.ChatManageService; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; @@ -24,11 +25,8 @@ import java.util.List; @RequestMapping({"/api/chat/manage", "/openapi/chat/manage"}) public class ChatController { - private final ChatManageService chatService; - - public ChatController(ChatManageService chatService) { - this.chatService = chatService; - } + @Autowired + private ChatManageService chatService; @PostMapping("/save") public Boolean save(@RequestParam(value = "chatName") String chatName, diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index b18900237..58607fa1c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -64,6 +64,7 @@ public class ChatServiceImpl implements ChatService { @Override public ParseResp performParsing(ChatParseReq chatParseReq) { + String queryText = chatParseReq.getQueryText(); ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText()); chatManageService.createChatQuery(chatParseReq, parseResp); ChatParseContext chatParseContext = buildParseContext(chatParseReq); @@ -73,6 +74,8 @@ public class ChatServiceImpl implements ChatService { for (ParseResultProcessor processor : parseResultProcessors) { processor.process(chatParseContext, parseResp); } + parseResp.setQueryText(queryText); + chatParseReq.setQueryText(queryText); chatManageService.batchAddParse(chatParseReq, parseResp); return parseResp; } diff --git a/chat/server/src/main/resources/mapper/ChatParseMapper.xml b/chat/server/src/main/resources/mapper/ChatParseMapper.xml index f5242e3fb..8e1e5336a 100644 --- a/chat/server/src/main/resources/mapper/ChatParseMapper.xml +++ b/chat/server/src/main/resources/mapper/ChatParseMapper.xml @@ -46,4 +46,10 @@ + + diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 65a2828c9..0a37ea86c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -143,7 +143,7 @@ public class LLMRequestService { return extraInfoSb.toString(); } - protected List getValueList(QueryContext queryCtx, Long dataSetId) { + public List getValueList(QueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java new file mode 100644 index 000000000..9e8b36bcb --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExamplarLoader.java @@ -0,0 +1,32 @@ +package com.tencent.supersonic.headless.core.chat.parser.llm; + + +import com.fasterxml.jackson.core.type.TypeReference; +import com.tencent.supersonic.common.util.JsonUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.core.io.ClassPathResource; +import org.springframework.stereotype.Component; + +import java.io.InputStream; +import java.util.List; +import java.util.ArrayList; + +@Slf4j +@Component +public class RewriteExamplarLoader { + + private static final String EXAMPLE_JSON_FILE = "rewrite_examplar.json"; + + private TypeReference> valueTypeRef = new TypeReference>() { + }; + + public List getRewriteExamples() { + try { + ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); + InputStream inputStream = resource.getInputStream(); + return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); + } catch (Exception e) { + return new ArrayList<>(); + } + } +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java new file mode 100644 index 000000000..d21467dca --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteExample.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.headless.core.chat.parser.llm; + +import lombok.Data; + +@Data +public class RewriteExample { + + private String contextualQuestions; + + private String currentQuestion; + + private String rewritingCurrentQuestion; + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java new file mode 100644 index 000000000..0abd6cdc0 --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/RewriteQueryGeneration.java @@ -0,0 +1,54 @@ +package com.tencent.supersonic.headless.core.chat.parser.llm; + +import com.tencent.supersonic.common.util.JsonUtil; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.Response; +import lombok.extern.slf4j.Slf4j; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Service +@Slf4j +public class RewriteQueryGeneration { + + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + @Autowired + private ChatLanguageModel chatLanguageModel; + + @Autowired + private RewriteExamplarLoader rewriteExamplarLoader; + + @Autowired + private SqlPromptGenerator sqlPromptGenerator; + + public String generation(String currentPromptStr) { + //1.retriever sqlExamples + + List> rewriteExamples = rewriteExamplarLoader.getRewriteExamples().stream().map(o -> { + return JsonUtil.toMap(JsonUtil.toString(o), String.class, String.class); + }).collect(Collectors.toList()); + + //2.generator linking and sql prompt by sqlExamples,and generate response. + String promptStr = sqlPromptGenerator.generateRewritePrompt(rewriteExamples) + currentPromptStr; + + Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>()); + keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); + Response response = chatLanguageModel.generate(prompt.toSystemMessage()); + String result = response.content().text(); + keyPipelineLog.info("model response:{}", result); + //3.format response. + String rewriteQuery = response.content().text(); + + return rewriteQuery; + } +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java index 889dabe46..651a07893 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java @@ -129,4 +129,19 @@ public class SqlPromptGenerator { return sqlPromptPool; } + public String generateRewritePrompt(List> rewriteExamples) { + String instruction = "#this is a multi-turn text-to-sql scenes,you need consider the contextual " + + "questions and semantics, rewriting current question for expressing complete semantics of " + + "the current question based on the contextual questions."; + List exampleKeys = Arrays.asList("contextualQuestions", "currentQuestion", "rewritingCurrentQuestion"); + StringBuilder rewriteSb = new StringBuilder(); + rewriteExamples.stream().forEach(o -> { + exampleKeys.stream().forEach(example -> { + rewriteSb.append(example + ":" + o.get(example) + "\n"); + }); + rewriteSb.append("\n"); + }); + return instruction + InputFormat.SEPERATOR + rewriteSb.toString(); + } + } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index e27912c8f..c57edf85b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -134,7 +134,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { return parseResult; } - private QueryContext buildQueryContext(QueryReq queryReq) { + public QueryContext buildQueryContext(QueryReq queryReq) { SemanticSchema semanticSchema = semanticService.getSemanticSchema(); Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index ad6d4b92d..a3fa6c3ea 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -102,3 +102,6 @@ inMemoryEmbeddingStore: query: optimizer: enable: true +multi: + turn: false + num: 5 diff --git a/launchers/standalone/src/main/resources/rewrite_examplar.json b/launchers/standalone/src/main/resources/rewrite_examplar.json new file mode 100644 index 000000000..03166fb35 --- /dev/null +++ b/launchers/standalone/src/main/resources/rewrite_examplar.json @@ -0,0 +1,122 @@ +[ + { + "contextualQuestions": "[“近7天纯音乐的歌曲播放量 (补充信息:’ '纯音乐'‘是一个’语种‘。)”]", + "currentQuestion": "对比翻唱版呢 (补充信息:’ '翻唱版'‘是一个’歌曲版本‘。)", + "rewritingCurrentQuestion": "对比近7天翻唱版和纯音乐的歌曲播放量" + }, + { + "contextualQuestions": "[]", + "currentQuestion": "robinlee在内容库的访问次数 (补充信息:’ 'robinlee'‘是一个’用户名‘。)", + "rewritingCurrentQuestion": "robinlee在内容库的访问次数" + }, + { + "contextualQuestions": "[\"robinlee在内容库的访问次数 (补充信息:’ 'robinlee'‘是一个’用户名‘。)\"]", + "currentQuestion": "对比jackjchen呢? (补充信息:’ 'jackjchen'‘是一个’用户名‘。)", + "rewritingCurrentQuestion": "robinlee对比jackjchen在内容库的访问次数" + }, + { + "contextualQuestions": "[\"robinlee在内容库的访问次数 (补充信息:’ 'robinlee'‘是一个’用户名‘。)\",\"对比jackjchen呢? (补充信息:’ 'jackjchen'‘是一个’用户名‘。)\"]。", + "currentQuestion": "内容库近12个月访问人数按部门", + "rewritingCurrentQuestion": "内容库近12个月访问人数按部门" + }, + { + "contextualQuestions": "[\"robinlee在内容库的访问次数 (补充信息:’ 'robinlee'‘是一个’用户名‘。)\",\"对比jackjchen呢? (补充信息:’ 'jackjchen'‘是一个’用户名‘。)\",\"内容库近12个月访问人数按部门\"]", + "currentQuestion": "访问次数呢?", + "rewritingCurrentQuestion": "内容库近12个月访问次数按部门" + }, + { + "contextualQuestions": "[]", + "currentQuestion": "近3天海田飞系MPPM结算播放份额 (补充信息:’'海田飞系'‘是一个’严选版权归属系‘)", + "rewritingCurrentQuestion": "近3天海田飞系MPPM结算播放份额" + }, + { + "contextualQuestions": "[\"近3天海田飞系MPPM结算播放份额(补充信息:’'海田飞系'‘是一个’严选版权归属系‘) \"]", + "currentQuestion": "近60天呢", + "rewritingCurrentQuestion": "近60天海田飞系MPPM结算播放份额" + }, + { + "contextualQuestions": "[\"近3天海田飞系MPPM结算播放份额(补充信息:’'海田飞系'‘是一个’严选版权归属系‘) \",\"近60天呢\"]", + "currentQuestion": "飞天系呢(补充信息:’'飞天系'‘是一个’严选版权归属系‘)", + "rewritingCurrentQuestion": "近60天飞天系MPPM结算播放份额" + }, + { + "contextualQuestions": "[“近90天袁亚伟播放量是多少 (补充信息:'袁亚伟'是一个歌手名)”]", + "currentQuestion": "平均值是多少", + "rewritingCurrentQuestion": "近90天袁亚伟播放量的平均值是多少" + }, + { + "contextualQuestions": "[“近90天袁亚伟播放量是多少 (补充信息:'袁亚伟'是一个歌手名)”,\"平均值是多少\",\"总和是多少\"]", + "currentQuestion": "总和是多少", + "rewritingCurrentQuestion": "近90天袁亚伟播放量的总和是多少" + }, + { + "contextualQuestions": "[\"播放量大于1万的歌曲有多少\"]", + "currentQuestion": "下载量大于10万的呢", + "rewritingCurrentQuestion": "下载量大于10万的歌曲有多少" + }, + { + "contextualQuestions": "[\"周杰伦2023年6月之后发布的歌曲有哪些(补充信息:'周杰伦'是一个歌手名)\"]", + "currentQuestion": "这些歌曲有哪些播放量大于500W的?", + "rewritingCurrentQuestion": "周杰伦2023年6月之后发布的歌曲,有哪些播放量大于500W的?" + }, + { + "contextualQuestions": "[“陈奕迅唱的所有的播放量大于20万的歌曲有哪些(补充信息:'陈奕迅'是一个歌手名)”]", + "currentQuestion": "大于100万的呢", + "rewritingCurrentQuestion": "陈奕迅唱的所有的播放量大于100万的歌曲有哪些" + }, + { + "contextualQuestions": "[“陈奕迅唱的所有的播放量大于20万的歌曲有哪些(补充信息:'陈奕迅'是一个歌手名)”,\"大于100万的呢\"]", + "currentQuestion": "周杰伦去年发布的歌曲有哪些(补充信息:'周杰伦'是一个歌手名)", + "rewritingCurrentQuestion": "周杰伦去年发布的歌曲有哪些" + }, + { + "contextualQuestions": "[“陈奕迅唱的所有的播放量大于20万的歌曲有哪些(补充信息:'陈奕迅'是一个歌手名)”,\"大于100万的呢\",\"周杰伦去年发布的歌曲有哪些(补充信息:'周杰伦'是一个歌手名)\"]", + "currentQuestion": "他今年发布的呢", + "rewritingCurrentQuestion": "周杰伦今年发布的歌曲有哪些" + }, + { + "contextualQuestions": "[“陈奕迅唱的所有的播放量大于20万的歌曲有哪些(补充信息:'陈奕迅'是一个歌手名)”,\"大于100万的呢\",\"周杰伦去年发布的歌曲有哪些(补充信息:'周杰伦'是一个歌手名)\",\"他今年发布的呢\"]", + "currentQuestion": "我想要近半年签约的播放量前十的歌手有哪些", + "rewritingCurrentQuestion": "我想要近半年签约的播放量前十的歌手有哪些" + }, + { + "contextualQuestions": "[]", + "currentQuestion": "最近一年发行的歌曲中,有哪些在近7天播放超过一千万的", + "rewritingCurrentQuestion": "最近一年发行的歌曲中,有哪些在近7天播放超过一千万的" + }, + { + "contextualQuestions": "[“最近一年发行的歌曲中,有哪些在近7天播放超过一千万的”]", + "currentQuestion": "今年以来呢?", + "rewritingCurrentQuestion": "今年以来发行的歌曲中,有哪些在近7天播放超过一千万的" + }, + { + "contextualQuestions": "[“最近一年发行的歌曲中,有哪些在近7天播放超过一千万的”,\"今年以来呢?\"]", + "currentQuestion": "2023年以来呢?", + "rewritingCurrentQuestion": "2023年以来发行的歌曲中,有哪些在近7天播放超过一千万的" + }, + { + "contextualQuestions": "[\"内容库近20天访问次数\"]", + "currentQuestion": "按部门看一下", + "rewritingCurrentQuestion": "内容库近20天按部门的访问次数" + }, + { + "contextualQuestions": "[\"内容库近20天访问次数\",\"按部门看一下\"]", + "currentQuestion": "按模块看一下", + "rewritingCurrentQuestion": "内容库近20天按模块的访问次数" + }, + { + "contextualQuestions": "[\"内容库近20天访问次数\",\"按部门看一下\",\"按模块看一下\"]", + "currentQuestion": "看一下技术部的 (补充信息:’'技术部'‘是一个’部门‘)", + "rewritingCurrentQuestion": "技术部在内容库近20天的访问次数" + }, + { + "contextualQuestions": "[\"内容库近20天访问次数\",\"按部门看一下\",\"按模块看一下\",\"看一下技术部的 (补充信息:’'技术部'‘是一个’部门‘)\"]", + "currentQuestion": "看一下产品部的 (补充信息:’'产品部'‘是一个’部门‘)", + "rewritingCurrentQuestion": "产品部在内容库近20天的访问次数" + }, + { + "contextualQuestions": "[\"内容库近20天访问次数\",\"按部门看一下\",\"按模块看一下\",\"看一下技术部的 (补充信息:’'技术部'‘是一个’部门‘)\",\"看一下产品部的 (补充信息:’'产品部'‘是一个’部门‘)\"]", + "currentQuestion": "对比一下技术部、产品部(补充信息:'技术部'、‘产品部’分别是一个’部门‘)", + "rewritingCurrentQuestion": "对比一下技术部、产品部在内容库近20天的访问次数" + } +]