mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(common) Fixed the compatibility issue with qwen. (#1193)
This commit is contained in:
@@ -1,55 +1,54 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.common.config.LLMConfig;
|
import com.tencent.supersonic.common.config.LLMConfig;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
|
||||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import java.util.Collections;
|
|
||||||
|
|
||||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MultiTurnParser implements ChatParser {
|
public class MultiTurnParser implements ChatParser {
|
||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
private static final String instruction = ""
|
private static final String instruction = ""
|
||||||
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
||||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||||
+ "please try understanding the semantics and rewrite a question.\n"
|
+ "please try understanding the semantics and rewrite a question.\n"
|
||||||
+ "#Rules: "
|
+ "#Rules: "
|
||||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
||||||
+ "2.ONLY respond with the rewritten question.\n"
|
+ "2.ONLY respond with the rewritten question.\n"
|
||||||
+ "#Current Question: %s\n"
|
+ "#Current Question: %s\n"
|
||||||
+ "#Current Mapped Schema: %s\n"
|
+ "#Current Mapped Schema: %s\n"
|
||||||
+ "#History Question: %s\n"
|
+ "#History Question: %s\n"
|
||||||
+ "#History Mapped Schema: %s\n"
|
+ "#History Mapped Schema: %s\n"
|
||||||
+ "#History SQL: %s\n"
|
+ "#History SQL: %s\n"
|
||||||
+ "#Rewritten Question: ";
|
+ "#Rewritten Question: ";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||||
@@ -79,13 +78,13 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL();
|
||||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||||
.curtQuestion(currentMapResult.getQueryText())
|
.curtQuestion(currentMapResult.getQueryText())
|
||||||
.histQuestion(lastParseResult.getQueryText())
|
.histQuestion(lastParseResult.getQueryText())
|
||||||
.curtSchema(curtMapStr)
|
.curtSchema(curtMapStr)
|
||||||
.histSchema(histMapStr)
|
.histSchema(histMapStr)
|
||||||
.histSQL(histSQL)
|
.histSQL(histSQL)
|
||||||
.llmConfig(queryReq.getLlmConfig())
|
.llmConfig(queryReq.getLlmConfig())
|
||||||
.build());
|
.build());
|
||||||
chatParseContext.setQueryText(rewrittenQuery);
|
chatParseContext.setQueryText(rewrittenQuery);
|
||||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||||
@@ -98,7 +97,7 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
keyPipelineLog.info("MultiTurnParser reqPrompt:{}", promptStr);
|
keyPipelineLog.info("MultiTurnParser reqPrompt:{}", promptStr);
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
|
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
String result = response.content().text();
|
String result = response.content().text();
|
||||||
keyPipelineLog.info("MultiTurnParser modelResp:{}", result);
|
keyPipelineLog.info("MultiTurnParser modelResp:{}", result);
|
||||||
@@ -144,6 +143,7 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
public static class RewriteContext {
|
public static class RewriteContext {
|
||||||
|
|
||||||
private String curtQuestion;
|
private String curtQuestion;
|
||||||
private String histQuestion;
|
private String histQuestion;
|
||||||
private String curtSchema;
|
private String curtSchema;
|
||||||
|
|||||||
@@ -228,6 +228,12 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-dashscope</artifactId>
|
<artifactId>langchain4j-dashscope</artifactId>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>org.slf4j</groupId>
|
||||||
|
<artifactId>slf4j-simple</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
<!--langchain4j-->
|
<!--langchain4j-->
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|||||||
@@ -39,9 +39,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
//3.perform multiple self-consistency inferences parallelly
|
//3.perform multiple self-consistency inferences parallelly
|
||||||
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
||||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
|
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
String result = response.content().text();
|
String result = response.content().text();
|
||||||
prompt2Output.put(prompt, result);
|
prompt2Output.put(prompt, result);
|
||||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||||
|
|||||||
@@ -99,7 +99,14 @@ langchain4j:
|
|||||||
model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo}
|
model-name: ${OPENAI_MODEL_NAME:gpt-3.5-turbo}
|
||||||
temperature: ${OPENAI_TEMPERATURE:0.0}
|
temperature: ${OPENAI_TEMPERATURE:0.0}
|
||||||
timeout: ${OPENAI_TIMEOUT:PT60S}
|
timeout: ${OPENAI_TIMEOUT:PT60S}
|
||||||
# java.lang.RuntimeException: dev.ai4j.openai4j.OpenAiHttpException: Too many requests
|
|
||||||
# embedding-model:
|
# embedding-model:
|
||||||
# base-url: ${OPENAI_API_BASE:https://api.openai.com/v1}
|
# base-url: ${OPENAI_API_BASE:https://api.openai.com/v1}
|
||||||
# api-key: ${OPENAI_API_KEY:demo}
|
# api-key: ${OPENAI_API_KEY:demo}
|
||||||
|
# dashscope:
|
||||||
|
# chat-model:
|
||||||
|
# api-key: ${OPENAI_API_KEY:demo}
|
||||||
|
# model-name: qwen-max-1201
|
||||||
|
# embedding-model:
|
||||||
|
# api-key: ${OPENAI_API_KEY:demo}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
6
pom.xml
6
pom.xml
@@ -199,6 +199,12 @@
|
|||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-dashscope</artifactId>
|
<artifactId>langchain4j-dashscope</artifactId>
|
||||||
<version>${langchain4j.version}</version>
|
<version>${langchain4j.version}</version>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>org.slf4j</groupId>
|
||||||
|
<artifactId>slf4j-simple</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|||||||
Reference in New Issue
Block a user