mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(common)Rename SqlExemplar to Text2SQLExemplar.
This commit is contained in:
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -41,9 +41,9 @@ public class SqlExecutor implements ChatQueryExecutor {
|
||||
|
||||
if (queryResult.getQueryState().equals(QueryState.SUCCESS)
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
SqlExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||
Text2SQLExemplar exemplar = JsonUtil.toObject(JsonUtil.toString(
|
||||
executeContext.getParseInfo().getProperties()
|
||||
.get(SqlExemplar.PROPERTY_KEY)), SqlExemplar.class);
|
||||
.get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class);
|
||||
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
|
||||
@@ -32,7 +32,8 @@ public class MemoryReviewTask {
|
||||
+ "please take a review and give your opinion.\n"
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||
+ "2.ALWAYS recognize `数据日期` as the date field.\n"
|
||||
+ "2.ALWAYS recognize `数据日期` as the date field."
|
||||
+ "3.IGNORE `数据日期` if not expressed in the `Question`."
|
||||
+ "#Question: %s\n"
|
||||
+ "#Schema: %s\n"
|
||||
+ "#SideInfo: %s\n"
|
||||
@@ -57,12 +58,12 @@ public class MemoryReviewTask {
|
||||
m.getSideInfo(), m.getS2sql());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
chatAgent.getModelConfig());
|
||||
if (Objects.nonNull(chatLanguageModel)) {
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||
|
||||
Matcher matcher = OUTPUT_PATTERN.matcher(response);
|
||||
if (matcher.find()) {
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -207,7 +207,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
|
||||
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||
String errMsg, List<SqlExemplar> similarExemplars,
|
||||
String errMsg, List<Text2SQLExemplar> similarExemplars,
|
||||
List<String> agentExamples) {
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("user_question", userQuestion);
|
||||
@@ -276,7 +276,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||
queryNLReq.getQueryText(), 5);
|
||||
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -43,7 +43,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
List<Text2SQLExemplar> exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
|
||||
return exemplars.stream().map(sqlExemplar ->
|
||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import java.util.List;
|
||||
@@ -100,7 +100,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
public void enableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.ENABLED);
|
||||
exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
Text2SQLExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
@@ -112,7 +112,7 @@ public class MemoryServiceImpl implements MemoryService {
|
||||
public void disableMemory(ChatMemoryDO memory) {
|
||||
memory.setStatus(MemoryStatus.DISABLED);
|
||||
exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
|
||||
SqlExemplar.builder()
|
||||
Text2SQLExemplar.builder()
|
||||
.question(memory.getQuestion())
|
||||
.sideInfo(memory.getSideInfo())
|
||||
.dbSchema(memory.getDbSchema())
|
||||
|
||||
@@ -9,7 +9,7 @@ import lombok.NoArgsConstructor;
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class SqlExemplar {
|
||||
public class Text2SQLExemplar {
|
||||
|
||||
public static final String PROPERTY_KEY = "sql_exemplar";
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.tencent.supersonic.common.service;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ExemplarService {
|
||||
void storeExemplar(String collection, SqlExemplar exemplar);
|
||||
void storeExemplar(String collection, Text2SQLExemplar exemplar);
|
||||
|
||||
void removeExemplar(String collection, SqlExemplar exemplar);
|
||||
void removeExemplar(String collection, Text2SQLExemplar exemplar);
|
||||
|
||||
List<SqlExemplar> recallExemplars(String collection, String query, int num);
|
||||
List<Text2SQLExemplar> recallExemplars(String collection, String query, int num);
|
||||
|
||||
List<SqlExemplar> recallExemplars(String query, int num);
|
||||
List<Text2SQLExemplar> recallExemplars(String query, int num);
|
||||
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
@@ -31,7 +31,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
|
||||
private static final String SYS_EXEMPLAR_FILE = "s2-exemplar.json";
|
||||
|
||||
private TypeReference<List<SqlExemplar>> valueTypeRef = new TypeReference<List<SqlExemplar>>() {
|
||||
private TypeReference<List<Text2SQLExemplar>> valueTypeRef = new TypeReference<List<Text2SQLExemplar>>() {
|
||||
};
|
||||
|
||||
private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper();
|
||||
@@ -42,7 +42,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
public void storeExemplar(String collection, SqlExemplar exemplar) {
|
||||
public void storeExemplar(String collection, Text2SQLExemplar exemplar) {
|
||||
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||
String.class, Object.class));
|
||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||
@@ -51,7 +51,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
embeddingService.addQuery(collection, Lists.newArrayList(segment));
|
||||
}
|
||||
|
||||
public void removeExemplar(String collection, SqlExemplar exemplar) {
|
||||
public void removeExemplar(String collection, Text2SQLExemplar exemplar) {
|
||||
Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar),
|
||||
String.class, Object.class));
|
||||
TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata);
|
||||
@@ -59,20 +59,20 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
embeddingService.deleteQuery(collection, Lists.newArrayList(segment));
|
||||
}
|
||||
|
||||
public List<SqlExemplar> recallExemplars(String query, int num) {
|
||||
public List<Text2SQLExemplar> recallExemplars(String query, int num) {
|
||||
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||
return recallExemplars(collection, query, num);
|
||||
}
|
||||
|
||||
public List<SqlExemplar> recallExemplars(String collection, String query, int num) {
|
||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||
public List<Text2SQLExemplar> recallExemplars(String collection, String query, int num) {
|
||||
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(Lists.newArrayList(query))
|
||||
.build();
|
||||
List<RetrieveQueryResult> results = embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||
results.stream().forEach(ret -> {
|
||||
ret.getRetrieval().stream().forEach(r -> {
|
||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), SqlExemplar.class));
|
||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -91,7 +91,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
||||
private void loadSysExemplars() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
List<SqlExemplar> exemplars = objectMapper.readValue(inputStream, valueTypeRef);
|
||||
List<Text2SQLExemplar> exemplars = objectMapper.readValue(inputStream, valueTypeRef);
|
||||
String collection = embeddingConfig.getText2sqlCollectionName();
|
||||
exemplars.stream().forEach(e -> storeExemplar(collection, e));
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
@@ -29,6 +29,6 @@ public class QueryNLReq {
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
@@ -52,7 +52,7 @@ public class ChatQueryContext {
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
@@ -36,13 +36,13 @@ public class LLMResponseService {
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
SqlExemplar exemplar = SqlExemplar.builder()
|
||||
Text2SQLExemplar exemplar = Text2SQLExemplar.builder()
|
||||
.question(queryCtx.getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput())
|
||||
.build();
|
||||
properties.put(SqlExemplar.PROPERTY_KEY, exemplar);
|
||||
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -45,11 +45,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<SqlExemplar> exemplars : exemplarsList) {
|
||||
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Text2SQLExemplar> exemplars : exemplarsList) {
|
||||
llmReq.setDynamicExemplars(exemplars);
|
||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
@@ -61,9 +61,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String result = response.content().text();
|
||||
output2Prompt.put(result, prompt);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||
String sqlOutput = StringUtils.normalizeSpace(response.content().text());
|
||||
output2Prompt.put(sqlOutput, prompt);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutput);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -71,7 +71,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
||||
Lists.newArrayList(output2Prompt.keySet()));
|
||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||
List<SqlExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
List<Text2SQLExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
@@ -79,7 +79,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
||||
StringBuilder exemplars = new StringBuilder();
|
||||
for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(),
|
||||
exemplar.getSideInfo(), exemplar.getSql());
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -29,12 +29,12 @@ public class PromptHelper {
|
||||
@Autowired
|
||||
private ExemplarService exemplarService;
|
||||
|
||||
public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
||||
exemplars.add(e);
|
||||
});
|
||||
@@ -44,10 +44,10 @@ public class PromptHelper {
|
||||
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
||||
}
|
||||
|
||||
List<List<SqlExemplar>> results = new ArrayList<>();
|
||||
List<List<Text2SQLExemplar>> results = new ArrayList<>();
|
||||
// use random collection of exemplars for each self-consistency inference
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, fewShotNumber));
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -54,7 +54,7 @@ public class ResponseHelper {
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Text2SQLExemplar> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
if (sqlMap == null) {
|
||||
return new HashMap<>();
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -30,7 +30,7 @@ public class LLMReq {
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
|
||||
@Data
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -16,6 +16,6 @@ public class LLMSqlResp {
|
||||
|
||||
private double sqlWeight;
|
||||
|
||||
private List<SqlExemplar> fewShots;
|
||||
private List<Text2SQLExemplar> fewShots;
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user