(improvement)(Headless) support multiturn text-to-sql (#983)

This commit is contained in:
mainmain
2024-05-13 14:13:02 +08:00
committed by GitHub
parent 947a01e8ba
commit 0e28d6cbcc
15 changed files with 407 additions and 11 deletions

View File

@@ -143,7 +143,7 @@ public class LLMRequestService {
return extraInfoSb.toString();
}
protected List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
public List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
import java.io.InputStream;
import java.util.List;
import java.util.ArrayList;
@Slf4j
@Component
public class RewriteExamplarLoader {
private static final String EXAMPLE_JSON_FILE = "rewrite_examplar.json";
private TypeReference<List<RewriteExample>> valueTypeRef = new TypeReference<List<RewriteExample>>() {
};
public List<RewriteExample> getRewriteExamples() {
try {
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
InputStream inputStream = resource.getInputStream();
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
} catch (Exception e) {
return new ArrayList<>();
}
}
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import lombok.Data;
@Data
public class RewriteExample {
private String contextualQuestions;
private String currentQuestion;
private String rewritingCurrentQuestion;
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Service
@Slf4j
public class RewriteQueryGeneration {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private RewriteExamplarLoader rewriteExamplarLoader;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public String generation(String currentPromptStr) {
//1.retriever sqlExamples
List<Map<String, String>> rewriteExamples = rewriteExamplarLoader.getRewriteExamples().stream().map(o -> {
return JsonUtil.toMap(JsonUtil.toString(o), String.class, String.class);
}).collect(Collectors.toList());
//2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = sqlPromptGenerator.generateRewritePrompt(rewriteExamples) + currentPromptStr;
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
//3.format response.
String rewriteQuery = response.content().text();
return rewriteQuery;
}
}

View File

@@ -129,4 +129,19 @@ public class SqlPromptGenerator {
return sqlPromptPool;
}
public String generateRewritePrompt(List<Map<String, String>> rewriteExamples) {
String instruction = "#this is a multi-turn text-to-sql scenes,you need consider the contextual "
+ "questions and semantics, rewriting current question for expressing complete semantics of "
+ "the current question based on the contextual questions.";
List<String> exampleKeys = Arrays.asList("contextualQuestions", "currentQuestion", "rewritingCurrentQuestion");
StringBuilder rewriteSb = new StringBuilder();
rewriteExamples.stream().forEach(o -> {
exampleKeys.stream().forEach(example -> {
rewriteSb.append(example + ":" + o.get(example) + "\n");
});
rewriteSb.append("\n");
});
return instruction + InputFormat.SEPERATOR + rewriteSb.toString();
}
}