From 1df1fe5ad6d09978134290588730d246c26a2bf4 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 20 Jun 2024 15:03:00 +0800 Subject: [PATCH] (improvement)(chat)Optimize multi-turn prompts. --- .../chat/server/parser/MultiTurnParser.java | 42 +++++++++---------- .../parser/llm/OnePassSCSqlGenStrategy.java | 8 ++-- .../test/resources/META-INF/spring.factories | 1 + pom.xml | 5 ++- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java index 227f0ad72..a90b727d3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -24,8 +24,6 @@ import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.Map; -import java.util.HashMap; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -38,17 +36,20 @@ public class MultiTurnParser implements ChatParser { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private static final PromptTemplate promptTemplate = PromptTemplate.from( - "You are a data product manager experienced in data requirements." - + "Your will be provided with current and history questions asked by a user," - + "along with their mapped schema elements(metric, dimension and value), " - + "please try understanding the semantics and rewrite a question" - + "(keep relevant entities, metrics, dimensions, values and date ranges)." - + "Current Question: {{curtQuestion}} " - + "Current Mapped Schema: {{curtSchema}} " - + "History Question: {{histQuestion}} " - + "History Mapped Schema: {{histSchema}} " - + "Rewritten Question: "); + private static final String instruction = "" + + "#Role: You are a data product manager experienced in data requirements.\n" + + "#Task: Your will be provided with current and history questions asked by a user," + + "along with their mapped schema elements(metric, dimension and value)," + + "please try understanding the semantics and rewrite a question.\n" + + "#Rules: " + + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. " + + "2.ONLY respond with the rewritten question.\n" + + "#Current Question: %s\n" + + "#Current Mapped Schema: %s\n" + + "#History Question: %s\n" + + "#History Mapped Schema: %s\n" + + "#History SQL: %s\n" + + "#Rewritten Question: "; @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { @@ -76,11 +77,13 @@ public class MultiTurnParser implements ChatParser { String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); + String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL(); String rewrittenQuery = rewriteQuery(RewriteContext.builder() .curtQuestion(currentMapResult.getQueryText()) .histQuestion(lastParseResult.getQueryText()) .curtSchema(curtMapStr) .histSchema(histMapStr) + .histSQL(histSQL) .llmConfig(queryReq.getLlmConfig()) .build()); chatParseContext.setQueryText(rewrittenQuery); @@ -89,14 +92,10 @@ public class MultiTurnParser implements ChatParser { } private String rewriteQuery(RewriteContext context) { - Map variables = new HashMap<>(); - variables.put("curtQuestion", context.getCurtQuestion()); - variables.put("histQuestion", context.getHistQuestion()); - variables.put("curtSchema", context.getCurtSchema()); - variables.put("histSchema", context.getHistSchema()); - - Prompt prompt = promptTemplate.apply(variables); - keyPipelineLog.info("MultiTurnParser reqPrompt:{}", prompt.toSystemMessage()); + String promptStr = String.format(instruction, context.getCurtQuestion(), context.getCurtSchema(), + context.getHistQuestion(), context.getHistSchema(), context.getHistSQL()); + Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); + keyPipelineLog.info("MultiTurnParser reqPrompt:{}", promptStr); ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig()); Response response = chatLanguageModel.generate(prompt.toSystemMessage()); @@ -149,6 +148,7 @@ public class MultiTurnParser implements ChatParser { private String histQuestion; private String curtSchema; private String histSchema; + private String histSQL; private LLMConfig llmConfig; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 15b04680e..1c73244a7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -65,10 +65,10 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "#Task: You will be provided a natural language query asked by business users," + "please convert it to a SQL query so that relevant answer could be returned to the user " + "by executing the SQL query against underlying database.\n" - + "#Rules:\n" - + "1.ALWAYS use `数据日期` as the date field.\n" - + "2.ALWAYS use `datediff()` as the date function.\n" - + "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n" + + "#Rules:" + + "1.ALWAYS use `数据日期` as the date field." + + "2.ALWAYS use `datediff()` as the date function." + + "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query." + "4.ONLY respond with the converted SQL statement.\n" + "#Exemplars:\n%s" + "#UserQuery: %s " diff --git a/launchers/standalone/src/test/resources/META-INF/spring.factories b/launchers/standalone/src/test/resources/META-INF/spring.factories index 4d61b8c89..5794db3e0 100644 --- a/launchers/standalone/src/test/resources/META-INF/spring.factories +++ b/launchers/standalone/src/test/resources/META-INF/spring.factories @@ -11,6 +11,7 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\ com.tencent.supersonic.chat.server.parser.ChatParser=\ com.tencent.supersonic.chat.server.parser.NL2PluginParser, \ + com.tencent.supersonic.chat.server.parser.MultiTurnParser, \ com.tencent.supersonic.chat.server.parser.NL2SQLParser com.tencent.supersonic.chat.server.executor.ChatExecutor=\ diff --git a/pom.xml b/pom.xml index 2d02b8b94..86896325b 100644 --- a/pom.xml +++ b/pom.xml @@ -212,7 +212,10 @@ org.apache.maven.plugins maven-resources-plugin - 3.1.0 + 3.2.0 + + ${project.build.sourceEncoding} + org.apache.maven.plugins