From 2425067091d7d2a6a2b16ac2afd135a33edcba45 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 18 Jul 2024 14:19:56 +0800 Subject: [PATCH] (improvement)(common)Rename `SqlExemplar` to `Text2SQLExemplar`. --- .../chat/server/executor/SqlExecutor.java | 6 +++--- .../chat/server/memory/MemoryReviewTask.java | 7 ++++--- .../chat/server/parser/NL2SQLParser.java | 6 +++--- .../parse/QueryRecommendProcessor.java | 4 ++-- .../server/service/impl/MemoryServiceImpl.java | 6 +++--- ...{SqlExemplar.java => Text2SQLExemplar.java} | 2 +- .../common/service/ExemplarService.java | 10 +++++----- .../service/impl/ExemplarServiceImpl.java | 18 +++++++++--------- .../headless/api/pojo/request/QueryNLReq.java | 4 ++-- .../headless/chat/ChatQueryContext.java | 4 ++-- .../chat/parser/llm/LLMResponseService.java | 6 +++--- .../parser/llm/OnePassSCSqlGenStrategy.java | 18 +++++++++--------- .../headless/chat/parser/llm/PromptHelper.java | 10 +++++----- .../chat/parser/llm/ResponseHelper.java | 4 ++-- .../headless/chat/query/llm/s2sql/LLMReq.java | 4 ++-- .../chat/query/llm/s2sql/LLMSqlResp.java | 4 ++-- 16 files changed, 57 insertions(+), 56 deletions(-) rename common/src/main/java/com/tencent/supersonic/common/pojo/{SqlExemplar.java => Text2SQLExemplar.java} (92%) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index 1a69171d8..c7d38123b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.ResultFormatter; import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; @@ -41,9 +41,9 @@ public class SqlExecutor implements ChatQueryExecutor { if (queryResult.getQueryState().equals(QueryState.SUCCESS) && queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { - SqlExemplar exemplar = JsonUtil.toObject(JsonUtil.toString( + Text2SQLExemplar exemplar = JsonUtil.toObject(JsonUtil.toString( executeContext.getParseInfo().getProperties() - .get(SqlExemplar.PROPERTY_KEY)), SqlExemplar.class); + .get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class); MemoryService memoryService = ContextUtils.getBean(MemoryService.class); memoryService.createMemory(ChatMemoryDO.builder() diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index cbedf8991..326201424 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -32,7 +32,8 @@ public class MemoryReviewTask { + "please take a review and give your opinion.\n" + "#Rules: " + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." - + "2.ALWAYS recognize `数据日期` as the date field.\n" + + "2.ALWAYS recognize `数据日期` as the date field." + + "3.IGNORE `数据日期` if not expressed in the `Question`." + "#Question: %s\n" + "#Schema: %s\n" + "#SideInfo: %s\n" @@ -57,12 +58,12 @@ public class MemoryReviewTask { m.getSideInfo(), m.getS2sql()); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); - keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr); + keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( chatAgent.getModelConfig()); if (Objects.nonNull(chatLanguageModel)) { String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); - keyPipelineLog.info("MemoryReviewTask modelResp:{}", response); + keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response); Matcher matcher = OUTPUT_PATTERN.matcher(response); if (matcher.find()) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 1f5d60b36..1b5273f77 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.service.ChatContextService; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -207,7 +207,7 @@ public class NL2SQLParser implements ChatQueryParser { } private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion, - String errMsg, List similarExemplars, + String errMsg, List similarExemplars, List agentExamples) { Map variables = new HashMap<>(); variables.put("user_question", userQuestion); @@ -276,7 +276,7 @@ public class NL2SQLParser implements ChatQueryParser { ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); - List exemplars = exemplarManager.recallExemplars(memoryCollectionName, + List exemplars = exemplarManager.recallExemplars(memoryCollectionName, queryNLReq.getQueryText(), 5); queryNLReq.getDynamicExemplars().addAll(exemplars); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index e04921d50..b8d155ad6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; @@ -43,7 +43,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor { ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); - List exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5); + List exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5); return exemplars.stream().map(sqlExemplar -> SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build()) .collect(Collectors.toList()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index e5972fc96..182a7f4e4 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatMemoryRepository; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.BeanMapper; import java.util.List; @@ -100,7 +100,7 @@ public class MemoryServiceImpl implements MemoryService { public void enableMemory(ChatMemoryDO memory) { memory.setStatus(MemoryStatus.ENABLED); exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), - SqlExemplar.builder() + Text2SQLExemplar.builder() .question(memory.getQuestion()) .sideInfo(memory.getSideInfo()) .dbSchema(memory.getDbSchema()) @@ -112,7 +112,7 @@ public class MemoryServiceImpl implements MemoryService { public void disableMemory(ChatMemoryDO memory) { memory.setStatus(MemoryStatus.DISABLED); exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), - SqlExemplar.builder() + Text2SQLExemplar.builder() .question(memory.getQuestion()) .sideInfo(memory.getSideInfo()) .dbSchema(memory.getDbSchema()) diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java similarity index 92% rename from common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java rename to common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java index f6bf2f162..5fe7b27da 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SqlExemplar.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java @@ -9,7 +9,7 @@ import lombok.NoArgsConstructor; @Builder @NoArgsConstructor @AllArgsConstructor -public class SqlExemplar { +public class Text2SQLExemplar { public static final String PROPERTY_KEY = "sql_exemplar"; diff --git a/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java b/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java index e7ae8554b..69a21fe16 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/ExemplarService.java @@ -1,16 +1,16 @@ package com.tencent.supersonic.common.service; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import java.util.List; public interface ExemplarService { - void storeExemplar(String collection, SqlExemplar exemplar); + void storeExemplar(String collection, Text2SQLExemplar exemplar); - void removeExemplar(String collection, SqlExemplar exemplar); + void removeExemplar(String collection, Text2SQLExemplar exemplar); - List recallExemplars(String collection, String query, int num); + List recallExemplars(String collection, String query, int num); - List recallExemplars(String query, int num); + List recallExemplars(String query, int num); } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java index e8adf9b1b..5bac8bf5e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java @@ -4,10 +4,10 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.Lists; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.pojo.SqlExemplar; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.RetrieveQuery; @@ -31,7 +31,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { private static final String SYS_EXEMPLAR_FILE = "s2-exemplar.json"; - private TypeReference> valueTypeRef = new TypeReference>() { + private TypeReference> valueTypeRef = new TypeReference>() { }; private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper(); @@ -42,7 +42,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { @Autowired private EmbeddingService embeddingService; - public void storeExemplar(String collection, SqlExemplar exemplar) { + public void storeExemplar(String collection, Text2SQLExemplar exemplar) { Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class)); TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata); @@ -51,7 +51,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { embeddingService.addQuery(collection, Lists.newArrayList(segment)); } - public void removeExemplar(String collection, SqlExemplar exemplar) { + public void removeExemplar(String collection, Text2SQLExemplar exemplar) { Metadata metadata = Metadata.from(JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class)); TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata); @@ -59,20 +59,20 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { embeddingService.deleteQuery(collection, Lists.newArrayList(segment)); } - public List recallExemplars(String query, int num) { + public List recallExemplars(String query, int num) { String collection = embeddingConfig.getText2sqlCollectionName(); return recallExemplars(collection, query, num); } - public List recallExemplars(String collection, String query, int num) { - List exemplars = Lists.newArrayList(); + public List recallExemplars(String collection, String query, int num) { + List exemplars = Lists.newArrayList(); RetrieveQuery retrieveQuery = RetrieveQuery.builder() .queryTextsList(Lists.newArrayList(query)) .build(); List results = embeddingService.retrieveQuery(collection, retrieveQuery, num); results.stream().forEach(ret -> { ret.getRetrieval().stream().forEach(r -> { - exemplars.add(JsonUtil.mapToObject(r.getMetadata(), SqlExemplar.class)); + exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class)); }); }); @@ -91,7 +91,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { private void loadSysExemplars() throws IOException { ClassPathResource resource = new ClassPathResource(SYS_EXEMPLAR_FILE); InputStream inputStream = resource.getInputStream(); - List exemplars = objectMapper.readValue(inputStream, valueTypeRef); + List exemplars = objectMapper.readValue(inputStream, valueTypeRef); String collection = embeddingConfig.getText2sqlCollectionName(); exemplars.stream().forEach(e -> storeExemplar(collection, e)); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index 557fd03f3..353a03c0c 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -5,7 +5,7 @@ import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; @@ -29,6 +29,6 @@ public class QueryNLReq { private QueryDataType queryDataType = QueryDataType.ALL; private ChatModelConfig modelConfig; private PromptConfig promptConfig; - private List dynamicExemplars = Lists.newArrayList(); + private List dynamicExemplars = Lists.newArrayList(); private SemanticParseInfo contextParseInfo; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index c35abc0c0..43932745d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.QueryDataType; @@ -52,7 +52,7 @@ public class ChatQueryContext { private QueryDataType queryDataType = QueryDataType.ALL; private ChatModelConfig modelConfig; private PromptConfig promptConfig; - private List dynamicExemplars; + private List dynamicExemplars; private SemanticParseInfo contextParseInfo; public List getCandidateQueries() { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 7cf6b17d5..03dc92a6b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery; @@ -36,13 +36,13 @@ public class LLMResponseService { Map properties = new HashMap<>(); properties.put(Constants.CONTEXT, parseResult); properties.put("type", "internal"); - SqlExemplar exemplar = SqlExemplar.builder() + Text2SQLExemplar exemplar = Text2SQLExemplar.builder() .question(queryCtx.getQueryText()) .sideInfo(parseResult.getLlmResp().getSideInfo()) .dbSchema(parseResult.getLlmResp().getSchema()) .sql(parseResult.getLlmResp().getSqlOutput()) .build(); - properties.put(SqlExemplar.PROPERTY_KEY, exemplar); + properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar); parseInfo.setProperties(properties); parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight)); parseInfo.setQueryMode(semanticQuery.getQueryMode()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index e8063350c..b1ad55329 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; import com.tencent.supersonic.common.config.PromptConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.data.message.AiMessage; @@ -45,11 +45,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { llmResp.setQuery(llmReq.getQueryText()); //1.recall exemplars keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq); - List> exemplarsList = promptHelper.getFewShotExemplars(llmReq); + List> exemplarsList = promptHelper.getFewShotExemplars(llmReq); //2.generate sql generation prompt for each self-consistency inference - Map> prompt2Exemplar = new HashMap<>(); - for (List exemplars : exemplarsList) { + Map> prompt2Exemplar = new HashMap<>(); + for (List exemplars : exemplarsList) { llmReq.setDynamicExemplars(exemplars); Prompt prompt = generatePrompt(llmReq, llmResp); prompt2Exemplar.put(prompt, exemplars); @@ -61,9 +61,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage()); ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); - String result = response.content().text(); - output2Prompt.put(result, prompt); - keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result); + String sqlOutput = StringUtils.normalizeSpace(response.content().text()); + output2Prompt.put(sqlOutput, prompt); + keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutput); } ); @@ -71,7 +71,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { Pair> sqlMapPair = ResponseHelper.selfConsistencyVote( Lists.newArrayList(output2Prompt.keySet())); llmResp.setSqlOutput(sqlMapPair.getLeft()); - List usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft())); + List usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft())); llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight())); return llmResp; @@ -79,7 +79,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) { StringBuilder exemplars = new StringBuilder(); - for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) { + for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) { String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n", exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSideInfo(), exemplar.getSql()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index b06d3ecb8..c3ebaf15c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; @@ -29,12 +29,12 @@ public class PromptHelper { @Autowired private ExemplarService exemplarService; - public List> getFewShotExemplars(LLMReq llmReq) { + public List> getFewShotExemplars(LLMReq llmReq) { int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER)); int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER)); - List exemplars = Lists.newArrayList(); + List exemplars = Lists.newArrayList(); llmReq.getDynamicExemplars().stream().forEach(e -> { exemplars.add(e); }); @@ -44,10 +44,10 @@ public class PromptHelper { exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize)); } - List> results = new ArrayList<>(); + List> results = new ArrayList<>(); // use random collection of exemplars for each self-consistency inference for (int i = 0; i < selfConsistencyNumber; i++) { - List shuffledList = new ArrayList<>(exemplars); + List shuffledList = new ArrayList<>(exemplars); Collections.shuffle(shuffledList); results.add(shuffledList.subList(0, fewShotNumber)); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java index 0b6d18655..40db0567b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.chat.parser.llm; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; @@ -54,7 +54,7 @@ public class ResponseHelper { return Pair.of(inputMax, votePercentage); } - public static Map buildSqlRespMap(List sqlExamples, + public static Map buildSqlRespMap(List sqlExamples, Map sqlMap) { if (sqlMap == null) { return new HashMap<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 901557216..335cbe905 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; @@ -30,7 +30,7 @@ public class LLMReq { private ChatModelConfig modelConfig; private PromptConfig promptConfig; - private List dynamicExemplars; + private List dynamicExemplars; @Data diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java index f4efe51bc..f0ab841e2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlResp.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; -import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -16,6 +16,6 @@ public class LLMSqlResp { private double sqlWeight; - private List fewShots; + private List fewShots; }