mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(chat)Implement a new version of multi-turn conversation.
This commit is contained in:
@@ -1,32 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.InputStream;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class RewriteExamplarLoader {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "rewrite_examplar.json";
|
||||
|
||||
private TypeReference<List<RewriteExample>> valueTypeRef = new TypeReference<List<RewriteExample>>() {
|
||||
};
|
||||
|
||||
public List<RewriteExample> getRewriteExamples() {
|
||||
try {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
} catch (Exception e) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class RewriteExample {
|
||||
|
||||
private String contextualQuestions;
|
||||
|
||||
private String currentQuestion;
|
||||
|
||||
private String rewritingCurrentQuestion;
|
||||
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class RewriteQueryGeneration {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
@Autowired
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private RewriteExamplarLoader rewriteExamplarLoader;
|
||||
|
||||
@Autowired
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
public String generation(String currentPromptStr) {
|
||||
//1.retriever sqlExamples
|
||||
|
||||
List<Map<String, String>> rewriteExamples = rewriteExamplarLoader.getRewriteExamples().stream().map(o -> {
|
||||
return JsonUtil.toMap(JsonUtil.toString(o), String.class, String.class);
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
String promptStr = sqlPromptGenerator.generateRewritePrompt(rewriteExamples) + currentPromptStr;
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
//3.format response.
|
||||
String rewriteQuery = response.content().text();
|
||||
|
||||
return rewriteQuery;
|
||||
}
|
||||
}
|
||||
@@ -129,19 +129,4 @@ public class SqlPromptGenerator {
|
||||
return sqlPromptPool;
|
||||
}
|
||||
|
||||
public String generateRewritePrompt(List<Map<String, String>> rewriteExamples) {
|
||||
String instruction = "#this is a multi-turn text-to-sql scenes,you need consider the contextual "
|
||||
+ "questions and semantics, rewriting current question for expressing complete semantics of "
|
||||
+ "the current question based on the contextual questions.";
|
||||
List<String> exampleKeys = Arrays.asList("contextualQuestions", "currentQuestion", "rewritingCurrentQuestion");
|
||||
StringBuilder rewriteSb = new StringBuilder();
|
||||
rewriteExamples.stream().forEach(o -> {
|
||||
exampleKeys.stream().forEach(example -> {
|
||||
rewriteSb.append(example + ":" + o.get(example) + "\n");
|
||||
});
|
||||
rewriteSb.append("\n");
|
||||
});
|
||||
return instruction + InputFormat.SEPERATOR + rewriteSb.toString();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user