mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat)Optimize multi-turn prompts.
This commit is contained in:
@@ -24,8 +24,6 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
@@ -38,17 +36,20 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
private static final PromptTemplate promptTemplate = PromptTemplate.from(
|
private static final String instruction = ""
|
||||||
"You are a data product manager experienced in data requirements."
|
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
||||||
+ "Your will be provided with current and history questions asked by a user,"
|
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||||
+ "along with their mapped schema elements(metric, dimension and value), "
|
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||||
+ "please try understanding the semantics and rewrite a question"
|
+ "please try understanding the semantics and rewrite a question.\n"
|
||||||
+ "(keep relevant entities, metrics, dimensions, values and date ranges)."
|
+ "#Rules: "
|
||||||
+ "Current Question: {{curtQuestion}} "
|
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
||||||
+ "Current Mapped Schema: {{curtSchema}} "
|
+ "2.ONLY respond with the rewritten question.\n"
|
||||||
+ "History Question: {{histQuestion}} "
|
+ "#Current Question: %s\n"
|
||||||
+ "History Mapped Schema: {{histSchema}} "
|
+ "#Current Mapped Schema: %s\n"
|
||||||
+ "Rewritten Question: ");
|
+ "#History Question: %s\n"
|
||||||
|
+ "#History Mapped Schema: %s\n"
|
||||||
|
+ "#History SQL: %s\n"
|
||||||
|
+ "#Rewritten Question: ";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||||
@@ -76,11 +77,13 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
|
|
||||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||||
|
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
||||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||||
.curtQuestion(currentMapResult.getQueryText())
|
.curtQuestion(currentMapResult.getQueryText())
|
||||||
.histQuestion(lastParseResult.getQueryText())
|
.histQuestion(lastParseResult.getQueryText())
|
||||||
.curtSchema(curtMapStr)
|
.curtSchema(curtMapStr)
|
||||||
.histSchema(histMapStr)
|
.histSchema(histMapStr)
|
||||||
|
.histSQL(histSQL)
|
||||||
.llmConfig(queryReq.getLlmConfig())
|
.llmConfig(queryReq.getLlmConfig())
|
||||||
.build());
|
.build());
|
||||||
chatParseContext.setQueryText(rewrittenQuery);
|
chatParseContext.setQueryText(rewrittenQuery);
|
||||||
@@ -89,14 +92,10 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String rewriteQuery(RewriteContext context) {
|
private String rewriteQuery(RewriteContext context) {
|
||||||
Map<String, Object> variables = new HashMap<>();
|
String promptStr = String.format(instruction, context.getCurtQuestion(), context.getCurtSchema(),
|
||||||
variables.put("curtQuestion", context.getCurtQuestion());
|
context.getHistQuestion(), context.getHistSchema(), context.getHistSQL());
|
||||||
variables.put("histQuestion", context.getHistQuestion());
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
variables.put("curtSchema", context.getCurtSchema());
|
keyPipelineLog.info("MultiTurnParser reqPrompt:{}", promptStr);
|
||||||
variables.put("histSchema", context.getHistSchema());
|
|
||||||
|
|
||||||
Prompt prompt = promptTemplate.apply(variables);
|
|
||||||
keyPipelineLog.info("MultiTurnParser reqPrompt:{}", prompt.toSystemMessage());
|
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
|
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||||
@@ -149,6 +148,7 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
private String histQuestion;
|
private String histQuestion;
|
||||||
private String curtSchema;
|
private String curtSchema;
|
||||||
private String histSchema;
|
private String histSchema;
|
||||||
|
private String histSQL;
|
||||||
private LLMConfig llmConfig;
|
private LLMConfig llmConfig;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -65,10 +65,10 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "#Task: You will be provided a natural language query asked by business users,"
|
+ "#Task: You will be provided a natural language query asked by business users,"
|
||||||
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
|
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
|
||||||
+ "by executing the SQL query against underlying database.\n"
|
+ "by executing the SQL query against underlying database.\n"
|
||||||
+ "#Rules:\n"
|
+ "#Rules:"
|
||||||
+ "1.ALWAYS use `数据日期` as the date field.\n"
|
+ "1.ALWAYS use `数据日期` as the date field."
|
||||||
+ "2.ALWAYS use `datediff()` as the date function.\n"
|
+ "2.ALWAYS use `datediff()` as the date function."
|
||||||
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n"
|
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
|
||||||
+ "4.ONLY respond with the converted SQL statement.\n"
|
+ "4.ONLY respond with the converted SQL statement.\n"
|
||||||
+ "#Exemplars:\n%s"
|
+ "#Exemplars:\n%s"
|
||||||
+ "#UserQuery: %s "
|
+ "#UserQuery: %s "
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
|
|||||||
|
|
||||||
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
||||||
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
|
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
|
||||||
|
com.tencent.supersonic.chat.server.parser.MultiTurnParser, \
|
||||||
com.tencent.supersonic.chat.server.parser.NL2SQLParser
|
com.tencent.supersonic.chat.server.parser.NL2SQLParser
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
||||||
|
|||||||
5
pom.xml
5
pom.xml
@@ -212,7 +212,10 @@
|
|||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-resources-plugin</artifactId>
|
<artifactId>maven-resources-plugin</artifactId>
|
||||||
<version>3.1.0</version>
|
<version>3.2.0</version>
|
||||||
|
<configuration>
|
||||||
|
<propertiesEncoding>${project.build.sourceEncoding}</propertiesEncoding>
|
||||||
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
<plugin>
|
<plugin>
|
||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
|
|||||||
Reference in New Issue
Block a user