mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(common) Fixed the compatibility issue with qwen. (#1193)
This commit is contained in:
@@ -1,55 +1,54 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.Collections;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
public class MultiTurnParser implements ChatParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static final String instruction = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||
+ "please try understanding the semantics and rewrite a question.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
||||
+ "2.ONLY respond with the rewritten question.\n"
|
||||
+ "#Current Question: %s\n"
|
||||
+ "#Current Mapped Schema: %s\n"
|
||||
+ "#History Question: %s\n"
|
||||
+ "#History Mapped Schema: %s\n"
|
||||
+ "#History SQL: %s\n"
|
||||
+ "#Rewritten Question: ";
|
||||
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||
+ "please try understanding the semantics and rewrite a question.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
||||
+ "2.ONLY respond with the rewritten question.\n"
|
||||
+ "#Current Question: %s\n"
|
||||
+ "#Current Mapped Schema: %s\n"
|
||||
+ "#History Question: %s\n"
|
||||
+ "#History Mapped Schema: %s\n"
|
||||
+ "#History SQL: %s\n"
|
||||
+ "#Rewritten Question: ";
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
@@ -79,13 +78,13 @@ public class MultiTurnParser implements ChatParser {
|
||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||
.curtQuestion(currentMapResult.getQueryText())
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.histSQL(histSQL)
|
||||
.llmConfig(queryReq.getLlmConfig())
|
||||
.build());
|
||||
.curtQuestion(currentMapResult.getQueryText())
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.histSQL(histSQL)
|
||||
.llmConfig(queryReq.getLlmConfig())
|
||||
.build());
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
@@ -98,7 +97,7 @@ public class MultiTurnParser implements ChatParser {
|
||||
keyPipelineLog.info("MultiTurnParser reqPrompt:{}", promptStr);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("MultiTurnParser modelResp:{}", result);
|
||||
@@ -144,6 +143,7 @@ public class MultiTurnParser implements ChatParser {
|
||||
@Data
|
||||
@Builder
|
||||
public static class RewriteContext {
|
||||
|
||||
private String curtQuestion;
|
||||
private String histQuestion;
|
||||
private String curtSchema;
|
||||
|
||||
@@ -228,6 +228,12 @@
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<!--langchain4j-->
|
||||
<dependency>
|
||||
|
||||
@@ -39,9 +39,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
//3.perform multiple self-consistency inferences parallelly
|
||||
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String result = response.content().text();
|
||||
prompt2Output.put(prompt, result);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||
|
||||
@@ -99,7 +99,14 @@ langchain4j:
|
||||
model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo}
|
||||
temperature: ${OPENAI_TEMPERATURE:0.0}
|
||||
timeout: ${OPENAI_TIMEOUT:PT60S}
|
||||
# java.lang.RuntimeException: dev.ai4j.openai4j.OpenAiHttpException: Too many requests
|
||||
# embedding-model:
|
||||
# base-url: ${OPENAI_API_BASE:https://api.openai.com/v1}
|
||||
# api-key: ${OPENAI_API_KEY:demo}
|
||||
# api-key: ${OPENAI_API_KEY:demo}
|
||||
# dashscope:
|
||||
# chat-model:
|
||||
# api-key: ${OPENAI_API_KEY:demo}
|
||||
# model-name: qwen-max-1201
|
||||
# embedding-model:
|
||||
# api-key: ${OPENAI_API_KEY:demo}
|
||||
|
||||
|
||||
|
||||
6
pom.xml
6
pom.xml
@@ -199,6 +199,12 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
||||
Reference in New Issue
Block a user