From 67b69768df268e2a20b7c1062be76b4935eb52a6 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Fri, 24 May 2024 15:08:18 +0800 Subject: [PATCH] (improvement)(headless)Refactor LLMParser impl naming and structure. --- .../embedding/EmbeddingRecallRecognizer.java | 2 +- ...Service.java => KnowledgeBaseService.java} | 2 +- .../chat/mapper/HanlpDictMatchStrategy.java | 8 +-- .../core/chat/mapper/SearchMatchStrategy.java | 8 +-- .../core/chat/parser/JavaLLMProxy.java | 49 ------------------- .../llm/{SqlExample.java => Exemplar.java} | 2 +- ...amplarLoader.java => ExemplarManager.java} | 20 ++++---- .../core/chat/parser/llm/JavaLLMProxy.java | 28 +++++++++++ .../core/chat/parser/{ => llm}/LLMProxy.java | 7 +-- .../chat/parser/llm/LLMRequestService.java | 12 ++--- .../core/chat/parser/llm/LLMSqlParser.java | 2 +- ...tion.java => OnePassSCSqlGenStrategy.java} | 15 +++--- ...ration.java => OnePassSqlGenStrategy.java} | 13 +++-- .../core/chat/parser/llm/OutputFormat.java | 3 -- ...mptGenerator.java => PromptGenerator.java} | 2 +- .../chat/parser/{ => llm}/PythonLLMProxy.java | 23 ++------- ...SqlGeneration.java => SqlGenStrategy.java} | 13 +++-- .../parser/llm/SqlGenStrategyFactory.java | 19 +++++++ .../core/chat/parser/llm/SqlGeneration.java | 19 ------- .../chat/parser/llm/SqlGenerationFactory.java | 21 -------- ...tion.java => TwoPassSCSqlGenStrategy.java} | 18 +++---- ...ration.java => TwoPassSqlGenStrategy.java} | 16 +++--- .../core/chat/query/llm/s2sql/LLMReq.java | 14 +++--- .../core/config/OptimizationConfig.java | 12 ++--- .../headless/core/utils/ComponentFactory.java | 4 +- .../listener/ApplicationStartedListener.java | 8 +-- .../listener/FullMetaEmbeddingListener.java | 2 +- .../server/listener/SqlEmbeddingListener.java | 12 ++--- .../service/impl/ChatQueryServiceImpl.java | 6 +-- .../service/impl/DictTaskServiceImpl.java | 4 +- .../service/impl/SearchServiceImpl.java | 6 +-- ...{s2ql_examplar.json => s2ql_exemplar.json} | 0 ...{s2ql_examplar.json => s2ql_exemplar.json} | 0 33 files changed, 158 insertions(+), 212 deletions(-) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/{KnowledgeService.java => KnowledgeBaseService.java} (98%) delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/JavaLLMProxy.java rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{SqlExample.java => Exemplar.java} (92%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{SqlExamplarLoader.java => ExemplarManager.java} (81%) create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/{ => llm}/LLMProxy.java (61%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{OnePassSCSqlGeneration.java => OnePassSCSqlGenStrategy.java} (79%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{OnePassSqlGeneration.java => OnePassSqlGenStrategy.java} (77%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{SqlPromptGenerator.java => PromptGenerator.java} (99%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/{ => llm}/PythonLLMProxy.java (73%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{BaseSqlGeneration.java => SqlGenStrategy.java} (65%) create mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategyFactory.java delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGeneration.java delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenerationFactory.java rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{TwoPassSCSqlGeneration.java => TwoPassSCSqlGenStrategy.java} (80%) rename headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/{TwoPassSqlGeneration.java => TwoPassSqlGenStrategy.java} (77%) rename launchers/standalone/src/main/resources/{s2ql_examplar.json => s2ql_exemplar.json} (100%) rename launchers/standalone/src/test/resources/{s2ql_examplar.json => s2ql_exemplar.json} (100%) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java index 286862cd9..ada94dd44 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java @@ -11,7 +11,7 @@ import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; -import com.tencent.supersonic.headless.core.chat.parser.PythonLLMProxy; +import com.tencent.supersonic.headless.core.chat.parser.llm.PythonLLMProxy; import com.tencent.supersonic.headless.core.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/KnowledgeService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/KnowledgeBaseService.java similarity index 98% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/KnowledgeService.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/KnowledgeBaseService.java index f14e45d3b..7602f1839 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/KnowledgeService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/knowledge/KnowledgeBaseService.java @@ -13,7 +13,7 @@ import java.util.stream.Collectors; @Service @Slf4j -public class KnowledgeService { +public class KnowledgeBaseService { public void updateSemanticKnowledge(List natures) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java index ad9b8ee81..64720f1b4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; -import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; +import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import lombok.extern.slf4j.Slf4j; @@ -36,7 +36,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { private OptimizationConfig optimizationConfig; @Autowired - private KnowledgeService knowledgeService; + private KnowledgeBaseService knowledgeBaseService; @Override public Map> match(QueryContext queryContext, List terms, @@ -65,11 +65,11 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { String detectSegment, int offset) { // step1. pre search Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); - LinkedHashSet hanlpMapResults = knowledgeService.prefixSearch(detectSegment, + LinkedHashSet hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment, oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); // step2. suffix search - LinkedHashSet suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment, + LinkedHashSet suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment, oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/SearchMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/SearchMatchStrategy.java index f5d99e26e..2d28789e7 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/SearchMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/SearchMatchStrategy.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; -import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; +import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.core.chat.knowledge.SearchService; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -29,7 +29,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy { private static final int SEARCH_SIZE = 3; @Autowired - private KnowledgeService knowledgeService; + private KnowledgeBaseService knowledgeBaseService; @Override public Map> match(QueryContext queryContext, List originals, @@ -57,9 +57,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy { String detectSegment = text.substring(detectIndex); if (StringUtils.isNotEmpty(detectSegment)) { - List hanlpMapResults = knowledgeService.prefixSearch(detectSegment, + List hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment, SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds); - List suffixHanlpMapResults = knowledgeService.suffixSearch( + List suffixHanlpMapResults = knowledgeBaseService.suffixSearch( detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds); hanlpMapResults.addAll(suffixHanlpMapResults); // remove entity name where search diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/JavaLLMProxy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/JavaLLMProxy.java deleted file mode 100644 index ade4a0bf5..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/JavaLLMProxy.java +++ /dev/null @@ -1,49 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser; - - -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.core.chat.parser.llm.SqlGeneration; -import com.tencent.supersonic.headless.core.chat.parser.llm.SqlGenerationFactory; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; -import com.tencent.supersonic.headless.core.pojo.QueryContext; -import dev.langchain4j.model.chat.ChatLanguageModel; -import lombok.extern.slf4j.Slf4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.stereotype.Component; - -import java.util.Objects; - -/** - * LLMProxy based on langchain4j Java version. - */ -@Slf4j -@Component -public class JavaLLMProxy implements LLMProxy { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - - @Override - public boolean isSkip(QueryContext queryContext) { - ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); - if (Objects.isNull(chatLanguageModel)) { - log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName()); - return true; - } - return false; - } - - public LLMResp query2sql(LLMReq llmReq, Long dataSetId) { - - SqlGeneration sqlGeneration = SqlGenerationFactory.get( - SqlGenerationMode.getMode(llmReq.getSqlGenerationMode())); - String modelName = llmReq.getSchema().getDataSetName(); - LLMResp result = sqlGeneration.generation(llmReq, dataSetId); - result.setQuery(llmReq.getQueryText()); - result.setModelName(modelName); - return result; - } - -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlExample.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/Exemplar.java similarity index 92% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlExample.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/Exemplar.java index 20eec3372..480c5fb5d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlExample.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/Exemplar.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import lombok.Data; @Data -public class SqlExample { +public class Exemplar { private String question; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlExamplarLoader.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/ExemplarManager.java similarity index 81% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlExamplarLoader.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/ExemplarManager.java index aa0a1d149..d6579036d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlExamplarLoader.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/ExemplarManager.java @@ -27,29 +27,29 @@ import java.util.stream.Collectors; @Slf4j @Component -public class SqlExamplarLoader { +public class ExemplarManager { - private static final String EXAMPLE_JSON_FILE = "s2ql_examplar.json"; + private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json"; private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); - private TypeReference> valueTypeRef = new TypeReference>() { + private TypeReference> valueTypeRef = new TypeReference>() { }; @Autowired private EmbeddingConfig embeddingConfig; - public List getSqlExamples() throws IOException { + public List getExemplars() throws IOException { ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); InputStream inputStream = resource.getInputStream(); return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); } - public void addEmbeddingStore(List sqlExamples, String collectionName) { + public void addExemplars(List exemplars, String collectionName) { List queries = new ArrayList<>(); - for (int i = 0; i < sqlExamples.size(); i++) { - SqlExample sqlExample = sqlExamples.get(i); - String question = sqlExample.getQuestion(); - Map metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class); + for (int i = 0; i < exemplars.size(); i++) { + Exemplar exemplar = exemplars.get(i); + String question = exemplar.getQuestion(); + Map metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class); EmbeddingQuery embeddingQuery = new EmbeddingQuery(); embeddingQuery.setQueryId(String.valueOf(i)); embeddingQuery.setQuery(question); @@ -59,7 +59,7 @@ public class SqlExamplarLoader { s2EmbeddingStore.addQuery(collectionName, queries); } - public List> retrieverSqlExamples(String queryText, int maxResults) { + public List> recallExemplars(String queryText, int maxResults) { String collectionName = embeddingConfig.getText2sqlCollectionName(); RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText)) .queryEmbeddings(null).build(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java new file mode 100644 index 000000000..69395432d --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java @@ -0,0 +1,28 @@ +package com.tencent.supersonic.headless.core.chat.parser.llm; + + +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +/** + * LLMProxy based on langchain4j Java version. + */ +@Slf4j +@Component +public class JavaLLMProxy implements LLMProxy { + + public LLMResp text2sql(LLMReq llmReq) { + + SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get( + SqlGenType.getMode(llmReq.getSqlGenerationMode())); + String modelName = llmReq.getSchema().getDataSetName(); + LLMResp result = sqlGenStrategy.generate(llmReq); + result.setQuery(llmReq.getQueryText()); + result.setModelName(modelName); + return result; + } + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/LLMProxy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMProxy.java similarity index 61% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/LLMProxy.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMProxy.java index f2910bd21..63779c55a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/LLMProxy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMProxy.java @@ -1,7 +1,6 @@ -package com.tencent.supersonic.headless.core.chat.parser; +package com.tencent.supersonic.headless.core.chat.parser.llm; -import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; @@ -12,8 +11,6 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; */ public interface LLMProxy { - boolean isSkip(QueryContext queryContext); - - LLMResp query2sql(LLMReq llmReq, Long dataSetId); + LLMResp text2sql(LLMReq llmReq); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 1cd3ddae8..0018c1204 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -45,13 +45,12 @@ public class LLMRequestService { log.info("not enable llm, skip"); return true; } - if (ComponentFactory.getLLMProxy().isSkip(queryCtx)) { - return true; - } + if (SatisfactionChecker.isSkip(queryCtx)) { log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText()); return true; } + return false; } @@ -72,6 +71,7 @@ public class LLMRequestService { llmReq.setFilterCondition(filterCondition); LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); + llmSchema.setDataSetId(dataSetId); llmSchema.setDataSetName(dataSetIdToName.get(dataSetId)); llmSchema.setDomainName(dataSetIdToName.get(dataSetId)); @@ -95,13 +95,13 @@ public class LLMRequestService { currentDate = DateUtils.getBeforeDate(0); } llmReq.setCurrentDate(currentDate); - llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName()); + llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName()); llmReq.setLlmConfig(queryCtx.getLlmConfig()); return llmReq; } - public LLMResp requestLLM(LLMReq llmReq, Long dataSetId) { - return ComponentFactory.getLLMProxy().query2sql(llmReq, dataSetId); + public LLMResp invokeLLM(LLMReq llmReq) { + return ComponentFactory.getLLMProxy().text2sql(llmReq); } protected List getFieldNameList(QueryContext queryCtx, Long dataSetId, diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java index d84a65712..cc4507445 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java @@ -43,7 +43,7 @@ public class LLMSqlParser implements SemanticParser { List linkingValues = requestService.getValueList(queryCtx, dataSetId); SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues); - LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId); + LLMResp llmResp = requestService.invokeLLM(llmReq); if (Objects.isNull(llmResp)) { return; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java similarity index 79% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java index 6444bd73e..3e32f278a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -21,21 +20,21 @@ import java.util.stream.Collectors; @Service @Slf4j -public class OnePassSCSqlGeneration extends BaseSqlGeneration { +public class OnePassSCSqlGenStrategy extends SqlGenStrategy { @Override - public LLMResp generation(LLMReq llmReq, Long dataSetId) { + public LLMResp generate(LLMReq llmReq) { //1.retriever sqlExamples and generate exampleListPool - keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq); + keyPipelineLog.info("llmReq:{}", llmReq); - List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), + List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); - List>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples, + List>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples, optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); //2.generator linking and sql prompt by sqlExamples,and parallel generate response. - List linkingSqlPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, true); + List linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true); List llmResults = new CopyOnWriteArrayList<>(); linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> { Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt)) @@ -67,6 +66,6 @@ public class OnePassSCSqlGeneration extends BaseSqlGeneration { @Override public void afterPropertiesSet() { - SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this); + SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT_SELF_CONSISTENCY, this); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java similarity index 77% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java index c1c8328f4..77160e047 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp; import dev.langchain4j.data.message.AiMessage; @@ -19,17 +18,17 @@ import java.util.Map; @Service @Slf4j -public class OnePassSqlGeneration extends BaseSqlGeneration { +public class OnePassSqlGenStrategy extends SqlGenStrategy { @Override - public LLMResp generation(LLMReq llmReq, Long dataSetId) { + public LLMResp generate(LLMReq llmReq) { //1.retriever sqlExamples - keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq); - List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), + keyPipelineLog.info("llmReq:{}", llmReq); + List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); //2.generator linking and sql prompt by sqlExamples,and generate response. - String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples); + String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples); Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>()); keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); @@ -52,6 +51,6 @@ public class OnePassSqlGeneration extends BaseSqlGeneration { @Override public void afterPropertiesSet() { - SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.ONE_PASS_AUTO_COT, this); + SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OutputFormat.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OutputFormat.java index 8b59effef..082a8b989 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OutputFormat.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OutputFormat.java @@ -13,9 +13,6 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; -/*** - * output format - */ @Slf4j public class OutputFormat { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java similarity index 99% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java index 889dabe46..36c86fd5f 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PromptGenerator.java @@ -14,7 +14,7 @@ import java.util.Collections; @Component @Slf4j -public class SqlPromptGenerator { +public class PromptGenerator { public String generatorLinkingAndSqlPrompt(LLMReq llmReq, List> exampleList) { String instruction = diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/PythonLLMProxy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java similarity index 73% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/PythonLLMProxy.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java index a429319ec..be0c3a1f2 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/PythonLLMProxy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java @@ -1,15 +1,12 @@ -package com.tencent.supersonic.headless.core.chat.parser; +package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.core.config.LLMParserConfig; -import com.tencent.supersonic.headless.core.chat.parser.llm.OutputFormat; -import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.MapUtils; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpEntity; @@ -30,22 +27,12 @@ import java.util.ArrayList; @Component public class PythonLLMProxy implements LLMProxy { - private static final Logger keyPipelineLog = LoggerFactory.getLogger(PythonLLMProxy.class); + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Override - public boolean isSkip(QueryContext queryContext) { - LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class); - if (StringUtils.isEmpty(llmParserConfig.getUrl())) { - log.warn("llmParserUrl is empty, skip :{}", PythonLLMProxy.class.getName()); - return true; - } - return false; - } - - public LLMResp query2sql(LLMReq llmReq, Long dataSetId) { + public LLMResp text2sql(LLMReq llmReq) { long startTime = System.currentTimeMillis(); - log.info("requestLLM request, dataSetId:{},llmReq:{}", dataSetId, llmReq); - keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq); + log.info("requestLLM request, llmReq:{}", llmReq); + keyPipelineLog.info("llmReq:{}", llmReq); try { LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java similarity index 65% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java index 55441b1b1..c645e4cab 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/BaseSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java @@ -1,6 +1,8 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.headless.api.pojo.LLMConfig; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -10,22 +12,27 @@ import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +/** + * SqlGenStrategy abstracts generation step so that + * different LLM prompting strategies can be implemented. + */ @Service -public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean { +public abstract class SqlGenStrategy implements InitializingBean { protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); @Autowired - protected SqlExamplarLoader sqlExamplarLoader; + protected ExemplarManager exemplarManager; @Autowired protected OptimizationConfig optimizationConfig; @Autowired - protected SqlPromptGenerator sqlPromptGenerator; + protected PromptGenerator promptGenerator; protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) { return S2ChatModelProvider.provide(llmConfig); } + abstract LLMResp generate(LLMReq llmReq); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategyFactory.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategyFactory.java new file mode 100644 index 000000000..1036ac0c8 --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategyFactory.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.headless.core.chat.parser.llm; + +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class SqlGenStrategyFactory { + + private static Map sqlGenStrategyMap = new ConcurrentHashMap<>(); + + public static SqlGenStrategy get(LLMReq.SqlGenType strategyType) { + return sqlGenStrategyMap.get(strategyType); + } + + public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) { + sqlGenStrategyMap.put(strategy, sqlGenStrategy); + } +} \ No newline at end of file diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGeneration.java deleted file mode 100644 index d3f502548..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGeneration.java +++ /dev/null @@ -1,19 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; - -/** - * Sql Generation interface, generating SQL using a large model. - */ -public interface SqlGeneration { - - /*** - * generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq. - * @param llmReq - * @param dataSetId - * @return - */ - LLMResp generation(LLMReq llmReq, Long dataSetId); - -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenerationFactory.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenerationFactory.java deleted file mode 100644 index 6fd07789b..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenerationFactory.java +++ /dev/null @@ -1,21 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - - - -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -public class SqlGenerationFactory { - - private static Map sqlGenerationMap = new ConcurrentHashMap<>(); - - public static SqlGeneration get(SqlGenerationMode strategyType) { - return sqlGenerationMap.get(strategyType); - } - - public static void addSqlGenerationForFactory(SqlGenerationMode strategy, SqlGeneration sqlGeneration) { - sqlGenerationMap.put(strategy, sqlGeneration); - } -} \ No newline at end of file diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java similarity index 80% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java index 78e3f559c..0b51fd1ad 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -18,20 +18,20 @@ import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; @Service -public class TwoPassSCSqlGeneration extends BaseSqlGeneration { +public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { @Override - public LLMResp generation(LLMReq llmReq, Long dataSetId) { + public LLMResp generate(LLMReq llmReq) { //1.retriever sqlExamples and generate exampleListPool - keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq); - List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), + keyPipelineLog.info("llmReq:{}", llmReq); + List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); - List>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples, + List>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples, optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); //2.generator linking prompt,and parallel generate response. - List linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false); + List linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false); List linkingResults = new CopyOnWriteArrayList<>(); ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); linkingPromptPool.parallelStream().forEach( @@ -47,7 +47,7 @@ public class TwoPassSCSqlGeneration extends BaseSqlGeneration { List sortedList = OutputFormat.formatList(linkingResults); Pair> linkingMap = OutputFormat.selfConsistencyVote(sortedList); //3.generator sql prompt,and parallel generate response. - List sqlPromptPool = sqlPromptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool); + List sqlPromptPool = promptGenerator.generateSqlPromptPool(llmReq, sortedList, exampleListPool); List sqlTaskPool = new CopyOnWriteArrayList<>(); sqlPromptPool.parallelStream().forEach(sqlPrompt -> { Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>()); @@ -69,6 +69,6 @@ public class TwoPassSCSqlGeneration extends BaseSqlGeneration { @Override public void afterPropertiesSet() { - SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this); + SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java similarity index 77% rename from headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java rename to headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java index 148c6b3c9..2310b67f2 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGeneration.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -18,15 +18,15 @@ import java.util.Map; @Service @Slf4j -public class TwoPassSqlGeneration extends BaseSqlGeneration { +public class TwoPassSqlGenStrategy extends SqlGenStrategy { @Override - public LLMResp generation(LLMReq llmReq, Long dataSetId) { - keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq); - List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), + public LLMResp generate(LLMReq llmReq) { + keyPipelineLog.info("llmReq:{}", llmReq); + List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); - String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples); + String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples); Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>()); keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage()); @@ -34,7 +34,7 @@ public class TwoPassSqlGeneration extends BaseSqlGeneration { Response response = chatLanguageModel.generate(prompt.toSystemMessage()); keyPipelineLog.info("step one model response:{}", response.content().text()); String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text()); - String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples); + String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples); Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>()); keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage()); @@ -53,6 +53,6 @@ public class TwoPassSqlGeneration extends BaseSqlGeneration { @Override public void afterPropertiesSet() { - SqlGenerationFactory.addSqlGenerationForFactory(SqlGenerationMode.TWO_PASS_AUTO_COT, this); + SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java index f5b909f5c..f739b6c9a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java @@ -41,6 +41,8 @@ public class LLMReq { private String dataSetName; + private Long dataSetId; + private List fieldNameList; } @@ -51,7 +53,7 @@ public class LLMReq { private String tableName; } - public enum SqlGenerationMode { + public enum SqlGenType { ONE_PASS_AUTO_COT("1_pass_auto_cot"), @@ -64,7 +66,7 @@ public class LLMReq { private String name; - SqlGenerationMode(String name) { + SqlGenType(String name) { this.name = name; } @@ -73,10 +75,10 @@ public class LLMReq { return name; } - public static SqlGenerationMode getMode(String name) { - for (SqlGenerationMode sqlGenerationMode : SqlGenerationMode.values()) { - if (sqlGenerationMode.name.equals(name)) { - return sqlGenerationMode; + public static SqlGenType getMode(String name) { + for (SqlGenType sqlGenType : SqlGenType.values()) { + if (sqlGenType.name.equals(name)) { + return sqlGenType; } } return null; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java index 9ce610b16..6f6e537e4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless.core.config; import com.tencent.supersonic.common.service.SysParameterService; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; +import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -65,7 +65,7 @@ public class OptimizationConfig { private boolean useLinkingValueSwitch; @Value("${s2SQL.generation:TWO_PASS_AUTO_COT}") - private SqlGenerationMode sqlGenerationMode; + private LLMReq.SqlGenType sqlGenType; @Value("${s2SQL.use.switch:true}") private boolean useS2SqlSwitch; @@ -157,8 +157,8 @@ public class OptimizationConfig { return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch); } - public SqlGenerationMode getSqlGenerationMode() { - return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode); + public LLMReq.SqlGenType getSqlGenType() { + return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType); } public Integer getParseShowCount() { @@ -177,8 +177,8 @@ public class OptimizationConfig { return targetType.cast(Integer.parseInt(value)); } else if (targetType == Boolean.class) { return targetType.cast(Boolean.parseBoolean(value)); - } else if (targetType == SqlGenerationMode.class) { - return targetType.cast(SqlGenerationMode.valueOf(value)); + } else if (targetType == LLMReq.SqlGenType.class) { + return targetType.cast(LLMReq.SqlGenType.valueOf(value)); } } catch (Exception e) { log.error("convertValue", e); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java index 80b76c828..bbb4562a2 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.core.utils; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.core.cache.QueryCache; -import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy; -import com.tencent.supersonic.headless.core.chat.parser.LLMProxy; +import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy; +import com.tencent.supersonic.headless.core.chat.parser.llm.LLMProxy; import com.tencent.supersonic.headless.core.chat.parser.llm.DataSetResolver; import com.tencent.supersonic.headless.core.executor.QueryExecutor; import com.tencent.supersonic.headless.core.parser.SqlParser; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/ApplicationStartedListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/ApplicationStartedListener.java index 7edc9cd24..d416cbf9a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/ApplicationStartedListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/ApplicationStartedListener.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.server.listener; import com.tencent.supersonic.headless.core.chat.knowledge.DictWord; -import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; +import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.server.service.impl.WordService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -21,7 +21,7 @@ import java.util.concurrent.CompletableFuture; public class ApplicationStartedListener implements CommandLineRunner { @Autowired - private KnowledgeService knowledgeService; + private KnowledgeBaseService knowledgeBaseService; @Autowired private WordService wordService; @@ -37,7 +37,7 @@ public class ApplicationStartedListener implements CommandLineRunner { List dictWords = wordService.getAllDictWords(); wordService.setPreDictWords(dictWords); - knowledgeService.reloadAllData(dictWords); + knowledgeBaseService.reloadAllData(dictWords); log.debug("ApplicationStartedInit end"); isOk = true; @@ -72,7 +72,7 @@ public class ApplicationStartedListener implements CommandLineRunner { } log.info("dictWords has changed"); wordService.setPreDictWords(dictWords); - knowledgeService.updateOnlineKnowledge(wordService.getAllDictWords()); + knowledgeBaseService.updateOnlineKnowledge(wordService.getAllDictWords()); } catch (Exception e) { log.error("reloadKnowledge error", e); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FullMetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FullMetaEmbeddingListener.java index 36ebe6d78..1fa30b0b7 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FullMetaEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/FullMetaEmbeddingListener.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.server.listener; -import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy; +import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy; import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.server.schedule.EmbeddingTask; import lombok.extern.slf4j.Slf4j; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SqlEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SqlEmbeddingListener.java index ca9249e42..afc94f2bc 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SqlEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SqlEmbeddingListener.java @@ -1,9 +1,9 @@ package com.tencent.supersonic.headless.server.listener; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.headless.core.chat.parser.JavaLLMProxy; -import com.tencent.supersonic.headless.core.chat.parser.llm.SqlExamplarLoader; -import com.tencent.supersonic.headless.core.chat.parser.llm.SqlExample; +import com.tencent.supersonic.headless.core.chat.parser.llm.JavaLLMProxy; +import com.tencent.supersonic.headless.core.chat.parser.llm.ExemplarManager; +import com.tencent.supersonic.headless.core.chat.parser.llm.Exemplar; import com.tencent.supersonic.headless.core.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -19,7 +19,7 @@ import java.util.List; public class SqlEmbeddingListener implements CommandLineRunner { @Autowired - private SqlExamplarLoader sqlExamplarLoader; + private ExemplarManager exemplarManager; @Autowired private EmbeddingConfig embeddingConfig; @@ -31,9 +31,9 @@ public class SqlEmbeddingListener implements CommandLineRunner { public void initSqlExamples() { try { if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) { - List sqlExamples = sqlExamplarLoader.getSqlExamples(); + List exemplars = exemplarManager.getExemplars(); String collectionName = embeddingConfig.getText2sqlCollectionName(); - sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName); + exemplarManager.addExemplars(exemplars, collectionName); } } catch (Exception e) { log.error("initSqlExamples error", e); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index c57edf85b..11082dd97 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -42,7 +42,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; -import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; +import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.core.chat.knowledge.SearchService; import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper; @@ -96,7 +96,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Autowired private ChatContextService chatContextService; @Autowired - private KnowledgeService knowledgeService; + private KnowledgeBaseService knowledgeBaseService; @Autowired private QueryService queryService; @Autowired @@ -557,7 +557,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { Map> modelIdToDataSetIds = new HashMap<>(); modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds)); //search from prefixSearch - List hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(), + List hanlpMapResultList = knowledgeBaseService.prefixSearch(dimensionValueReq.getValue(), 2000, modelIdToDataSetIds, dataSetIds); HanlpHelper.transLetterOriginal(hanlpMapResultList); return hanlpMapResultList.stream() diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java index 17a1ec5ec..1151dabbc 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java @@ -9,7 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DictSingleTaskReq; import com.tencent.supersonic.headless.api.pojo.response.DictItemResp; import com.tencent.supersonic.headless.api.pojo.response.DictTaskResp; import com.tencent.supersonic.headless.core.file.FileHandler; -import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; +import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.server.persistence.dataobject.DictTaskDO; import com.tencent.supersonic.headless.server.persistence.repository.DictRepository; @@ -46,7 +46,7 @@ public class DictTaskServiceImpl implements DictTaskService { DictUtils dictConverter, DictUtils dictUtils, FileHandler fileHandler, - KnowledgeService knowledgeService) { + KnowledgeBaseService knowledgeBaseService) { this.dictRepository = dictRepository; this.dictConverter = dictConverter; this.dictUtils = dictUtils; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SearchServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SearchServiceImpl.java index 1d88d7302..781dc5783 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SearchServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SearchServiceImpl.java @@ -18,7 +18,7 @@ import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.knowledge.DataSetInfoStat; import com.tencent.supersonic.headless.core.chat.knowledge.DictWord; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; -import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; +import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper; import com.tencent.supersonic.headless.server.service.ChatContextService; @@ -58,7 +58,7 @@ public class SearchServiceImpl implements SearchService { @Autowired private ChatContextService chatContextService; @Autowired - private KnowledgeService knowledgeService; + private KnowledgeBaseService knowledgeBaseService; @Autowired private DataSetService dataSetService; @@ -73,7 +73,7 @@ public class SearchServiceImpl implements SearchService { Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser()); // 2.detect by segment - List originals = knowledgeService.getTerms(queryText, modelIdToDataSetIds); + List originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds); log.info("hanlp parse result: {}", originals); Set dataSetIds = queryReq.getDataSetIds(); diff --git a/launchers/standalone/src/main/resources/s2ql_examplar.json b/launchers/standalone/src/main/resources/s2ql_exemplar.json similarity index 100% rename from launchers/standalone/src/main/resources/s2ql_examplar.json rename to launchers/standalone/src/main/resources/s2ql_exemplar.json diff --git a/launchers/standalone/src/test/resources/s2ql_examplar.json b/launchers/standalone/src/test/resources/s2ql_exemplar.json similarity index 100% rename from launchers/standalone/src/test/resources/s2ql_examplar.json rename to launchers/standalone/src/test/resources/s2ql_exemplar.json