mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(feature)(chat&common) Introduce ChatMemory module to support dynamic few-shot exemplars. #1097
This commit is contained in:
@@ -1,14 +1,17 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.request;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@@ -24,4 +27,5 @@ public class QueryReq {
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
private List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
@@ -49,6 +50,7 @@ public class QueryContext {
|
||||
private WorkflowState workflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
private List<SqlExemplar> exemplars;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class Exemplar {
|
||||
|
||||
private String question;
|
||||
|
||||
private String questionAugmented;
|
||||
|
||||
private String dbSchema;
|
||||
|
||||
private String sql;
|
||||
|
||||
}
|
||||
@@ -1,101 +0,0 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@Order(0)
|
||||
public class ExemplarManager implements CommandLineRunner {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
|
||||
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
||||
};
|
||||
|
||||
@Override
|
||||
public void run(String... args) {
|
||||
try {
|
||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
||||
loadDefaultExemplars();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to init examples", e);
|
||||
}
|
||||
}
|
||||
|
||||
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
||||
List<TextSegment> queries = new ArrayList<>();
|
||||
for (int i = 0; i < exemplars.size(); i++) {
|
||||
Exemplar exemplar = exemplars.get(i);
|
||||
String question = exemplar.getQuestion();
|
||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
||||
TextSegment embeddingQuery = TextSegment.from(question, new Metadata(metaDataMap));
|
||||
TextSegmentConvert.addQueryId(embeddingQuery, String.valueOf(i));
|
||||
queries.add(embeddingQuery);
|
||||
}
|
||||
embeddingService.addQuery(collectionName, queries);
|
||||
}
|
||||
|
||||
public List<Map<String, String>> recallExemplars(String queryText, int maxResults) {
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||
.queryEmbeddings(null).build();
|
||||
|
||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(collectionName, retrieveQuery,
|
||||
maxResults);
|
||||
List<Map<String, String>> result = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return result;
|
||||
}
|
||||
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
|
||||
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
||||
result.add(convertedMap);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private void loadDefaultExemplars() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
List<Exemplar> examples = JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
addExemplars(examples, collectionName);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -109,6 +109,8 @@ public class LLMRequestService {
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
|
||||
llmReq.setExemplars(queryCtx.getExemplars());
|
||||
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -42,11 +43,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<Map<String, String>>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<Map<String, String>>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Map<String, String>> exemplars : exemplarsList) {
|
||||
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<SqlExemplar> exemplars : exemplarsList) {
|
||||
Prompt prompt = generatePrompt(llmReq, exemplars);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
@@ -67,25 +68,24 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
|
||||
Lists.newArrayList(prompt2Output.values()));
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq));
|
||||
llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq));
|
||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<SqlExemplar> fewshotExampleList) {
|
||||
StringBuilder exemplarsStr = new StringBuilder();
|
||||
for (Map<String, String> example : fewshotExampleList) {
|
||||
String metadata = example.get("dbSchema");
|
||||
String question = example.get("questionAugmented");
|
||||
String sql = example.get("sql");
|
||||
for (SqlExemplar exemplar : fewshotExampleList) {
|
||||
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
|
||||
question, metadata, sql);
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql());
|
||||
exemplarsStr.append(exemplarStr);
|
||||
}
|
||||
|
||||
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
|
||||
String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq);
|
||||
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
|
||||
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -108,7 +109,7 @@ public class OutputFormat {
|
||||
return results;
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
if (sqlMap == null) {
|
||||
return new HashMap<>();
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -11,7 +14,6 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
@@ -25,20 +27,27 @@ public class PromptHelper {
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
@Autowired
|
||||
private ExemplarManager exemplarManager;
|
||||
private ExemplarService exemplarService;
|
||||
|
||||
public List<List<Map<String, String>>> getFewShotExemplars(LLMReq llmReq) {
|
||||
public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<Map<String, String>> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||
llmReq.getExemplars().stream().forEach(e -> {
|
||||
exemplars.add(e);
|
||||
});
|
||||
|
||||
int recallSize = exemplarRecallNumber - llmReq.getExemplars().size();
|
||||
if (recallSize > 0) {
|
||||
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
||||
}
|
||||
|
||||
List<List<SqlExemplar>> results = new ArrayList<>();
|
||||
// use random collection of exemplars for each self-consistency inference
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exemplars);
|
||||
List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, fewShotNumber));
|
||||
}
|
||||
@@ -64,7 +73,7 @@ public class PromptHelper {
|
||||
linkingListStr, currentDataStr, termStr, priorExts);
|
||||
}
|
||||
|
||||
public String buildMetadataStr(LLMReq llmReq) {
|
||||
public String buildSchemaStr(LLMReq llmReq) {
|
||||
String tableStr = llmReq.getSchema().getDataSetName();
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -26,6 +27,9 @@ public class LLMReq {
|
||||
private SqlGenType sqlGenType;
|
||||
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
private List<SqlExemplar> exemplars;
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ public class LLMResp {
|
||||
|
||||
private String modelName;
|
||||
|
||||
private String dbSchema;
|
||||
|
||||
private String sqlOutput;
|
||||
|
||||
private List<String> fields;
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@@ -17,6 +16,6 @@ public class LLMSqlResp {
|
||||
|
||||
private double sqlWeight;
|
||||
|
||||
private List<Map<String, String>> fewShots;
|
||||
private List<SqlExemplar> fewShots;
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user