(improvement)(Headless) Embedding data related to metadata is not restored from persistent files. (#748)

This commit is contained in:
lexluo09
2024-02-23 18:22:28 +08:00
committed by GitHub
parent 01bc4dcacf
commit e610dd8246
13 changed files with 25 additions and 16 deletions

View File

@@ -73,9 +73,6 @@ public class OptimizationConfig {
@Value("${text2sql.self.consistency.num:5}")
private int text2sqlSelfConsistencyNum;
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
private String text2sqlCollectionName;
@Value("${parse.show.count:3}")
private Integer parseShowCount;

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;

View File

@@ -48,7 +48,7 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());

View File

@@ -45,7 +45,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
//1.retriever sqlExamples
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = sqlPromptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
@@ -19,6 +20,7 @@ import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
@@ -32,6 +34,9 @@ public class SqlExamplarLoader {
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
};
@Autowired
private EmbeddingConfig embeddingConfig;
public List<SqlExample> getSqlExamples() throws IOException {
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
InputStream inputStream = resource.getInputStream();
@@ -53,8 +58,8 @@ public class SqlExamplarLoader {
s2EmbeddingStore.addQuery(collectionName, queries);
}
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
public List<Map<String, String>> retrieverSqlExamples(String queryText, int maxResults) {
String collectionName = embeddingConfig.getText2sqlCollectionName();
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
.queryEmbeddings(null).build();

View File

@@ -44,7 +44,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
optimizationConfig.getText2sqlExampleNum());
List<List<Map<String, String>>> exampleListPool = sqlPromptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());

View File

@@ -43,7 +43,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
public LLMResp generation(LLMReq llmReq, Long viewId) {
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
optimizationConfig.getText2sqlExampleNum());
String linkingPromptStr = sqlPromptGenerator.generateLinkingPrompt(llmReq, sqlExamples);

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;

View File

@@ -1,4 +1,5 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;

View File

@@ -32,6 +32,9 @@ public class EmbeddingConfig {
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection;
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
private String text2sqlCollectionName;
@Value("${embedding.metric.analyzeQuery.nResult:5}")
private int metricAnalyzeQueryResultNum;

View File

@@ -48,7 +48,9 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = null;
Path filePath = getPersistentPath(collectionName);
try {
if (Files.exists(filePath)) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
&& !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries);
log.info("embeddingStore reload from file:{}", filePath);

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.parser.JavaLLMProxy;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExamplarLoader;
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlExample;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@@ -20,7 +20,7 @@ public class EmbeddingInitListener implements CommandLineRunner {
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
private EmbeddingConfig embeddingConfig;
@Override
public void run(String... args) {
@@ -31,7 +31,7 @@ public class EmbeddingInitListener implements CommandLineRunner {
try {
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
List<SqlExample> sqlExamples = sqlExamplarLoader.getSqlExamples();
String collectionName = optimizationConfig.getText2sqlCollectionName();
String collectionName = embeddingConfig.getText2sqlCollectionName();
sqlExamplarLoader.addEmbeddingStore(sqlExamples, collectionName);
}
} catch (Exception e) {