(improvement)(common) Fixed the compatibility issue with qwen. (#1193)

Author: lexluo09
Date: 2024-06-22 16:59:15 +08:00 (committed by GitHub)
Parent: 782d4ead9e
Commit: 32e2c1e39d
5 changed files with 53 additions and 34 deletions

File 1 of 5: MultiTurnParser.java

@@ -1,55 +1,54 @@
 package com.tencent.supersonic.chat.server.parser;
+import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
 import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
 import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
 import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
 import com.tencent.supersonic.chat.server.util.QueryReqConverter;
-import com.tencent.supersonic.common.util.ContextUtils;
 import com.tencent.supersonic.common.config.LLMConfig;
+import com.tencent.supersonic.common.util.ContextUtils;
+import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
 import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
 import com.tencent.supersonic.headless.api.pojo.response.MapResp;
 import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import com.tencent.supersonic.headless.server.service.ChatQueryService;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
 import dev.langchain4j.model.output.Response;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
 import lombok.Builder;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.Collections;
-import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
 @Slf4j
 public class MultiTurnParser implements ChatParser {

     private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

     private static final String instruction = ""
             + "#Role: You are a data product manager experienced in data requirements.\n"
             + "#Task: Your will be provided with current and history questions asked by a user,"
             + "along with their mapped schema elements(metric, dimension and value),"
             + "please try understanding the semantics and rewrite a question.\n"
             + "#Rules: "
             + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
             + "2.ONLY respond with the rewritten question.\n"
             + "#Current Question: %s\n"
             + "#Current Mapped Schema: %s\n"
             + "#History Question: %s\n"
             + "#History Mapped Schema: %s\n"
             + "#History SQL: %s\n"
             + "#Rewritten Question: ";

     @Override
     public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
@@ -79,13 +78,13 @@ public class MultiTurnParser implements ChatParser {
         String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
         String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
         String rewrittenQuery = rewriteQuery(RewriteContext.builder()
                 .curtQuestion(currentMapResult.getQueryText())
                 .histQuestion(lastParseResult.getQueryText())
                 .curtSchema(curtMapStr)
                 .histSchema(histMapStr)
                 .histSQL(histSQL)
                 .llmConfig(queryReq.getLlmConfig())
                 .build());
         chatParseContext.setQueryText(rewrittenQuery);
         log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
                 lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
@@ -98,7 +97,7 @@ public class MultiTurnParser implements ChatParser {
         keyPipelineLog.info("MultiTurnParser reqPrompt:{}", promptStr);
         ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
-        Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
+        Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
         String result = response.content().text();
         keyPipelineLog.info("MultiTurnParser modelResp:{}", result);
@@ -144,6 +143,7 @@ public class MultiTurnParser implements ChatParser {
     @Data
     @Builder
     public static class RewriteContext {
+
         private String curtQuestion;
         private String histQuestion;
         private String curtSchema;
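Note on the core fix: the parser used to hand the rendered prompt to the model as a single system message, and the call now goes through Prompt.toUserMessage() instead, which is what the Qwen/dashscope endpoint appears to require (a request with no user turn gets rejected). A minimal sketch of the pattern, using only the langchain4j types already imported above; the template text and the ask(...) helper are illustrative, not code from this commit:

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.Map;

public class QwenPromptRoleSketch {

    // Renders a prompt and sends it as a user message. Sending the same
    // prompt via prompt.toSystemMessage() (the pre-fix behaviour) leaves the
    // request without any user turn, which is what broke against Qwen.
    static String ask(ChatLanguageModel model, String question) {
        Prompt prompt = PromptTemplate
                .from("Rewrite the following question: {{question}}")
                .apply(Map.of("question", question));
        Response<AiMessage> response = model.generate(prompt.toUserMessage());
        return response.content().text();
    }
}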

File 2 of 5: pom.xml

@@ -228,6 +228,12 @@
         <dependency>
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-dashscope</artifactId>
+            <exclusions>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-simple</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <!--langchain4j-->
         <dependency>
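The slf4j-simple exclusion presumably keeps langchain4j-dashscope from dragging a second SLF4J binding onto the classpath next to the binding the project already uses. Which binding actually wins at runtime can be checked with a one-off snippet like the following (class name is illustrative):

import org.slf4j.LoggerFactory;

public class Slf4jBindingCheck {
    public static void main(String[] args) {
        // Prints the active SLF4J logger factory, e.g. Logback's or SimpleLogger's.
        System.out.println(LoggerFactory.getILoggerFactory().getClass().getName());
    }
}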

File 3 of 5: OnePassSCSqlGenStrategy.java

@@ -39,9 +39,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
         //3.perform multiple self-consistency inferences parallelly
         Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
         prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
-            keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
+            keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
             ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
-            Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
+            Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
             String result = response.content().text();
             prompt2Output.put(prompt, result);
             keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
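For context on the "perform multiple self-consistency inferences" comment above: self-consistency usually means sampling several candidates and keeping the answer produced most often. A hedged sketch of such a vote over the collected prompt2Output map; this is a generic illustration, not the selection logic OnePassSCSqlGenStrategy applies outside this hunk:

import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

class SelfConsistencyVoteSketch {

    // Returns the output generated most often across the parallel inferences.
    static Optional<String> majority(Map<?, String> prompt2Output) {
        return prompt2Output.values().stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey);
    }
}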

File 4 of 5: langchain4j YAML configuration

@@ -99,7 +99,14 @@ langchain4j:
       model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo}
       temperature: ${OPENAI_TEMPERATURE:0.0}
       timeout: ${OPENAI_TIMEOUT:PT60S}
+# java.lang.RuntimeException: dev.ai4j.openai4j.OpenAiHttpException: Too many requests
 #    embedding-model:
 #      base-url: ${OPENAI_API_BASE:https://api.openai.com/v1}
 #      api-key: ${OPENAI_API_KEY:demo}
+#  dashscope:
+#    chat-model:
+#      api-key: ${OPENAI_API_KEY:demo}
+#      model-name: qwen-max-1201
+#    embedding-model:
+#      api-key: ${OPENAI_API_KEY:demo}
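The commented-out dashscope block is the Qwen counterpart of the open-ai section above it. For reference, a hedged sketch of the equivalent programmatic setup with langchain4j-dashscope; the QwenChatModel builder calls should be double-checked against the langchain4j version pinned in the pom, and the environment variable is a placeholder, not something this commit introduces:

import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel;

public class QwenSetupSketch {
    public static void main(String[] args) {
        ChatLanguageModel model = QwenChatModel.builder()
                .apiKey(System.getenv("DASHSCOPE_API_KEY")) // placeholder env var
                .modelName("qwen-max-1201")                 // same model name as the sample config
                .build();
        // Simple smoke-test call.
        System.out.println(model.generate("ping"));
    }
}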

File 5 of 5: pom.xml

@@ -199,6 +199,12 @@
             <groupId>dev.langchain4j</groupId>
             <artifactId>langchain4j-dashscope</artifactId>
             <version>${langchain4j.version}</version>
+            <exclusions>
+                <exclusion>
+                    <groupId>org.slf4j</groupId>
+                    <artifactId>slf4j-simple</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>