diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java index f0782f85c..0059a8eb6 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/config/OptimizationConfig.java @@ -73,9 +73,6 @@ public class OptimizationConfig { @Value("${text2sql.self.consistency.num:5}") private int text2sqlSelfConsistencyNum; - @Value("${text2sql.collection.name:text2dsl_agent_collection}") - private String text2sqlCollectionName; - @Value("${parse.show.count:3}") private Integer parseShowCount; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java index b1d68fb9f..07ef29067 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/corrector/WhereCorrector.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper; +import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper; import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java index 3dce92fb3..cde924968 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm; import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSCSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSCSqlGeneration.java index 13eeee5fa..1342bf0cd 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSCSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSCSqlGeneration.java @@ -48,7 +48,7 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean { keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), - optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); + optimizationConfig.getText2sqlExampleNum()); List>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples, optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSqlGeneration.java index a3fe2748f..a98486ae1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/OnePassSqlGeneration.java @@ -45,7 +45,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean { //1.retriever sqlExamples keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), - optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); + optimizationConfig.getText2sqlExampleNum()); //2.generator linking and sql prompt by sqlExamples,and generate response. String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/SqlExamplarLoader.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/SqlExamplarLoader.java index 7aa5c869a..7a586e9d3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/SqlExamplarLoader.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/SqlExamplarLoader.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm; import com.fasterxml.jackson.core.type.TypeReference; +import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; @@ -19,6 +20,7 @@ import java.util.Objects; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.io.ClassPathResource; import org.springframework.stereotype.Component; @@ -32,6 +34,9 @@ public class SqlExamplarLoader { private TypeReference> valueTypeRef = new TypeReference>() { }; + @Autowired + private EmbeddingConfig embeddingConfig; + public List getSqlExamples() throws IOException { ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); InputStream inputStream = resource.getInputStream(); @@ -53,8 +58,8 @@ public class SqlExamplarLoader { s2EmbeddingStore.addQuery(collectionName, queries); } - public List> retrieverSqlExamples(String queryText, String collectionName, int maxResults) { - + public List> retrieverSqlExamples(String queryText, int maxResults) { + String collectionName = embeddingConfig.getText2sqlCollectionName(); RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText)) .queryEmbeddings(null).build(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSCSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSCSqlGeneration.java index 3ba1452b5..a90af3086 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSCSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSCSqlGeneration.java @@ -44,7 +44,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean { //1.retriever sqlExamples and generate exampleListPool keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), - optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); + optimizationConfig.getText2sqlExampleNum()); List>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples, optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSqlGeneration.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSqlGeneration.java index aaaf3eb68..61fe24b60 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSqlGeneration.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/TwoPassSqlGeneration.java @@ -43,7 +43,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean { public LLMResp generation(LLMReq llmReq, Long viewId) { keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq); List> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(), - optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum()); + optimizationConfig.getText2sqlExampleNum()); String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/S2SqlDateHelper.java similarity index 98% rename from chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/S2SqlDateHelper.java index 5421d47c2..741f5f89c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/utils/S2SqlDateHelper.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.core.parser.sql.llm; +package com.tencent.supersonic.chat.core.utils; import com.tencent.supersonic.chat.api.pojo.ViewSchema; import com.tencent.supersonic.chat.core.pojo.QueryContext; diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelperTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/core/utils/S2SqlDateHelperTest.java similarity index 98% rename from chat/core/src/test/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelperTest.java rename to chat/core/src/test/java/com/tencent/supersonic/chat/core/utils/S2SqlDateHelperTest.java index acca8a813..2066c0bf9 100644 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/core/parser/sql/llm/S2SqlDateHelperTest.java +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/core/utils/S2SqlDateHelperTest.java @@ -1,4 +1,5 @@ -package com.tencent.supersonic.chat.core.parser.sql.llm; +package com.tencent.supersonic.chat.core.utils; + import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.ViewSchema; diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java index b28d2931d..c8b9d6f9b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java @@ -32,6 +32,9 @@ public class EmbeddingConfig { @Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}") private String metricAnalyzeQueryCollection; + @Value("${text2sql.collection.name:text2dsl_agent_collection}") + private String text2sqlCollectionName; + @Value("${embedding.metric.analyzeQuery.nResult:5}") private int metricAnalyzeQueryResultNum; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java index d254e0637..6c063e19d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -48,7 +48,9 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { InMemoryEmbeddingStore embeddingStore = null; Path filePath = getPersistentPath(collectionName); try { - if (Files.exists(filePath)) { + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); + if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName()) + && !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) { embeddingStore = InMemoryEmbeddingStore.fromFile(filePath); embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries); log.info("embeddingStore reload from file:{}", filePath); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java b/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java index 46cb9b42e..efdd52353 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java @@ -1,10 +1,10 @@ package com.tencent.supersonic; -import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.parser.JavaLLMProxy; import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExamplarLoader; import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExample; import com.tencent.supersonic.chat.core.utils.ComponentFactory; +import com.tencent.supersonic.common.config.EmbeddingConfig; import java.util.List; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -20,7 +20,7 @@ public class EmbeddingInitListener implements CommandLineRunner { @Autowired private SqlExamplarLoader sqlExamplarLoader; @Autowired - private OptimizationConfig optimizationConfig; + private EmbeddingConfig embeddingConfig; @Override public void run(String... args) { @@ -31,7 +31,7 @@ public class EmbeddingInitListener implements CommandLineRunner { try { if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) { List sqlExamples = sqlExamplarLoader.getSqlExamples(); - String collectionName = optimizationConfig.getText2sqlCollectionName(); + String collectionName = embeddingConfig.getText2sqlCollectionName(); sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName); } } catch (Exception e) {