(improvement)(pyllm) Use the HTTP parameter llm_config in place of the default llm_config

This commit is contained in:
jerryjzhang
2024-05-20 17:40:34 +08:00
parent 53b6c03288
commit eaec7b4663
11 changed files with 106 additions and 86 deletions

View File

@@ -4,11 +4,13 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -50,7 +52,7 @@ public class MultiTurnParser implements ChatParser {
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
Environment environment = ContextUtils.getBean(Environment.class);
Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class);
if (Boolean.FALSE.equals(multiTurn)) {
if (!Boolean.TRUE.equals(multiTurn)) {
return;
}
@@ -73,6 +75,7 @@ public class MultiTurnParser implements ChatParser {
.histQuestion(lastParseResult.getQueryText())
.curtSchema(curtMapStr)
.histSchema(histMapStr)
.llmConfig(queryReq.getLlmConfig())
.build());
chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
@@ -80,7 +83,6 @@ public class MultiTurnParser implements ChatParser {
}
private String rewriteQuery(RewriteContext context) {
Map<String, Object> variables = new HashMap<>();
variables.put("curtQuestion", context.getCurtQuestion());
variables.put("histQuestion", context.getHistQuestion());
@@ -89,14 +91,13 @@ public class MultiTurnParser implements ChatParser {
Prompt prompt = promptTemplate.apply(variables);
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
//3.format response.
String rewriteQuery = response.content().text();
return rewriteQuery;
return response.content().text();
}
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
@@ -142,5 +143,6 @@ public class MultiTurnParser implements ChatParser {
private String histQuestion;
private String curtSchema;
private String histSchema;
private LLMConfig llmConfig;
}
}