(improvement)(headless&chat)Refactor system parameter impl

This commit is contained in:
jerryjzhang
2024-06-01 01:42:00 +08:00
parent 28960668ce
commit 0f0847824f
32 changed files with 494 additions and 432 deletions

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.MapperConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
@@ -27,7 +28,10 @@ import java.util.stream.Collectors;
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Autowired
private MapperHelper mapperHelper;
protected MapperHelper mapperHelper;
@Autowired
protected MapperConfig mapperConfig;
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,

View File

@@ -5,12 +5,10 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -22,6 +20,9 @@ import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
/**
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
* It currently supports fuzzy matching against names and aliases.
@@ -30,10 +31,6 @@ import java.util.stream.Collectors;
@Slf4j
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MapperHelper mapperHelper;
private List<SchemaElement> allElements;
@Override
@@ -94,9 +91,8 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
}
private Double getThreshold(QueryContext queryContext) {
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
@@ -22,6 +21,14 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_BATCH;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MAX;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MIN;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_NUMBER;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
/**
* EmbeddingMatchStrategy uses vector database to perform
* similarity search against the embeddings of schema elements.
@@ -30,9 +37,6 @@ import java.util.stream.Collectors;
@Slf4j
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MetaEmbeddingService metaEmbeddingService;
@@ -48,24 +52,27 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
Set<Long> detectDataSetIds, String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MIN));
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MAX));
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_BATCH));
List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
&& detectSegment.length() >= embedddingMapperMin
&& detectSegment.length() <= embedddingMapperMax)
.collect(Collectors.toList());
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
optimizationConfig.getEmbeddingMapperBatch());
embeddingMapperBatch);
for (List<String> queryTextsSub : queryTextsSubList) {
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
@@ -74,15 +81,16 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
List<String> queryTextsSub, QueryContext queryContext) {
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(),
optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum());
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
// step2. retrieveQuery by detectSegment
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
@@ -118,7 +126,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
.collect(Collectors.toList());
// step4. select mapResult in one round
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
List<EmbeddingResult> oneRoundResults = collect.stream()
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
.limit(roundNumber)

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -21,6 +20,14 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_MAX_SIZE;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_SIZE;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DIMENSION_VALUE_SIZE;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD_MIN;
/**
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
* match schema elements. It currently supports prefix and suffix matching
@@ -30,12 +37,6 @@ import java.util.stream.Collectors;
@Slf4j
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Autowired
private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@@ -65,7 +66,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
@@ -99,12 +100,13 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only M dimensionValue or N-M metric/dimension value per round.
int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.limit(optimizationConfig.getOneDetectionDimensionValueSize())
.limit(oneDetectionValueSize)
.collect(Collectors.toList());
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
List<HanlpMapResult> oneRoundResults = new ArrayList<>();
// add the dimensionValue if it exists
@@ -129,13 +131,14 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
if (mapperHelper.existDimensionValues(natures)) {
threshold = optimizationConfig.getDimensionValueThresholdConfig();
minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig();
threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD));
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD_MIN));
}
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
}
}

View File

@@ -2,11 +2,9 @@ package com.tencent.supersonic.headless.core.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Comparator;
@@ -20,9 +18,6 @@ import java.util.stream.Collectors;
@Slf4j
public class MapperHelper {
@Autowired
private OptimizationConfig optimizationConfig;
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
Integer subRegLength = regOffsetToLength.get(index);
if (Objects.nonNull(subRegLength)) {

View File

@@ -2,12 +2,16 @@ package com.tencent.supersonic.headless.core.chat.parser;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_LONG;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_SHORT;
/**
* This checker can be used by semantic parsers to check if query intent
* has already been satisfied by current candidate queries. If so, current
@@ -32,12 +36,19 @@ public class SatisfactionChecker {
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength;
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
if (degree < optimizationConfig.getLongTextThreshold()) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int textLengthThreshold =
Integer.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD));
double longTextLengthThreshold =
Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_LONG));
double shortTextLengthThreshold =
Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_SHORT));
if (queryTextLength > textLengthThreshold) {
if (degree < longTextLengthThreshold) {
return false;
}
} else if (degree < optimizationConfig.getShortTextThreshold()) {
} else if (degree < shortTextLengthThreshold) {
return false;
}
log.info("queryMode:{}, degree:{}, parse info:{}",

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@@ -16,8 +15,7 @@ public class JavaLLMProxy implements LLMProxy {
public LLMResp text2sql(LLMReq llmReq) {
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(
SqlGenType.getMode(llmReq.getSqlGenerationMode()));
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
String modelName = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGenStrategy.generate(llmReq);
result.setQuery(llmReq.getQueryText());

View File

@@ -11,8 +11,8 @@ import com.tencent.supersonic.headless.core.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.config.LLMParserConfig;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
@@ -31,14 +31,18 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_STRATEGY_TYPE;
@Slf4j
@Service
public class LLMRequestService {
@Autowired
private LLMParserConfig llmParserConfig;
@Autowired
private OptimizationConfig optimizationConfig;
private ParserConfig parserConfig;
public boolean isSkip(QueryContext queryCtx) {
if (!queryCtx.getText2SQLType().enableLLM()) {
@@ -86,7 +90,9 @@ public class LLMRequestService {
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
if (optimizationConfig.isUseLinkingValueSwitch()) {
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
if (linkingValueEnabled) {
linking.addAll(linkingValues);
}
llmReq.setLinking(linking);
@@ -96,7 +102,7 @@ public class LLMRequestService {
currentDate = DateUtils.getBeforeDate(0);
}
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName());
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setLlmConfig(queryCtx.getLlmConfig());
return llmReq;
}

View File

@@ -18,6 +18,11 @@ import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
@Service
@Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@@ -27,11 +32,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
fewShotNumber, selfConsistencyNumber);
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.slf4j.Logger;
@@ -25,7 +25,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
protected ExemplarManager exemplarManager;
@Autowired
protected OptimizationConfig optimizationConfig;
protected ParserConfig parserConfig;
@Autowired
protected PromptGenerator promptGenerator;

View File

@@ -17,6 +17,10 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
@Service
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
@@ -24,11 +28,15 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
fewShotNumber, selfConsistencyNumber);
//2.generator linking prompt,and parallel generate response.
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
@@ -21,6 +21,8 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_S2SQL_ENABLE;
@Slf4j
@ToString
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
@@ -73,8 +75,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
}
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (!optimizationConfig.isUseS2SqlSwitch()) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE));
if (!s2sqlEnable) {
return;
}
QueryStructReq queryStructReq = convertQueryStruct();

View File

@@ -22,7 +22,7 @@ public class LLMReq {
private String priorExts;
private String sqlGenerationMode;
private SqlGenType sqlGenType;
private LLMConfig llmConfig;
@@ -82,14 +82,5 @@ public class LLMReq {
return name;
}
public static SqlGenType getMode(String name) {
for (SqlGenType sqlGenType : SqlGenType.values()) {
if (sqlGenType.name.equals(name)) {
return sqlGenType;
}
}
return null;
}
}
}

View File

@@ -0,0 +1,110 @@
package com.tencent.supersonic.headless.core.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ParameterConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("HeadlessMapperConfig")
public class MapperConfig extends ParameterConfig {
public static final Parameter MAPPER_DETECTION_SIZE =
new Parameter("s2.mapper.detection.size", "8",
"一次探测返回结果个数",
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
"number", "Mapper相关配置");
public static final Parameter MAPPER_DETECTION_MAX_SIZE =
new Parameter("s2.mapper.detection.max.size", "20",
"一次探测前后缀匹配结果返回个数",
"单次前后缀匹配返回的结果个数",
"number", "Mapper相关配置");
public static final Parameter MAPPER_NAME_THRESHOLD =
new Parameter("s2.mapper.name.threshold", "0.3",
"指标名、维度名文本相似度阈值",
"文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置");
public static final Parameter MAPPER_NAME_THRESHOLD_MIN =
new Parameter("s2.mapper.name.min.threshold", "0.25",
"指标名、维度名最小文本相似度阈值",
"指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter MAPPER_DIMENSION_VALUE_SIZE =
new Parameter("s2.mapper.value.size", "1",
"指标名、维度名最小文本相似度阈值",
"指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter MAPPER_VALUE_THRESHOLD =
new Parameter("s2.mapper.value.threshold", "0.5",
"维度值文本相似度阈值",
"文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置");
public static final Parameter MAPPER_VALUE_THRESHOLD_MIN =
new Parameter("s2.mapper.value.min.threshold", "0.3",
"维度值最小文本相似度阈值",
"维度值相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_MIN =
new Parameter("s2.mapper.embedding.word.min", "4",
"用于向量召回最小的文本长度",
"为提高向量召回效率, 小于该长度的文本不进行向量语义召回",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_MAX =
new Parameter("s2.mapper.embedding.word.max", "5",
"用于向量召回最大的文本长度",
"为提高向量召回效率, 大于该长度的文本不进行向量语义召回",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_BATCH =
new Parameter("s2.mapper.embedding.batch", "50",
"批量向量召回文本请求个数",
"每次进行向量语义召回的原始文本片段个数",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_NUMBER =
new Parameter("s2.mapper.embedding.number", "5",
"批量向量召回文本返回结果个数",
"每个文本进行向量语义召回的文本结果个数",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.99",
"向量召回相似度阈值",
"相似度小于该阈值的则舍弃",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
new Parameter("s2.mapper.embedding.min.threshold", "0.9",
"向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
new Parameter("s2.mapper.embedding.round.number", "10",
"向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
MAPPER_DETECTION_SIZE,
MAPPER_DETECTION_MAX_SIZE,
MAPPER_NAME_THRESHOLD,
MAPPER_NAME_THRESHOLD_MIN,
MAPPER_DIMENSION_VALUE_SIZE,
MAPPER_VALUE_THRESHOLD,
MAPPER_VALUE_THRESHOLD_MIN
);
}
}

View File

@@ -1,193 +0,0 @@
package com.tencent.supersonic.headless.core.config;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
@Slf4j
public class OptimizationConfig {
@Value("${s2.one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${s2.one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@Value("${s2.one.detection.dimensionValue.size:1}")
private Integer oneDetectionDimensionValueSize;
@Value("${s2.metric.dimension.min.threshold:0.3}")
private Double metricDimensionMinThresholdConfig;
@Value("${s2.metric.dimension.threshold:0.3}")
private Double metricDimensionThresholdConfig;
@Value("${s2.dimension.value.min.threshold:0.2}")
private Double dimensionValueMinThresholdConfig;
@Value("${s2.dimension.value.threshold:0.5}")
private Double dimensionValueThresholdConfig;
@Value("${s2.long.text.threshold:0.8}")
private Double longTextThreshold;
@Value("${s2.short.text.threshold:0.5}")
private Double shortTextThreshold;
@Value("${s2.query.text.length.threshold:10}")
private Integer queryTextLengthThreshold;
@Value("${s2.embedding.mapper.word.min:4}")
private int embeddingMapperWordMin;
@Value("${s2.embedding.mapper.word.max:4}")
private int embeddingMapperWordMax;
@Value("${s2.embedding.mapper.batch:50}")
private int embeddingMapperBatch;
@Value("${s2.embedding.mapper.number:5}")
private int embeddingMapperNumber;
@Value("${s2.embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${s2.embedding.mapper.min.threshold:0.6}")
private Double embeddingMapperMinThreshold;
@Value("${s2.embedding.mapper.threshold:0.99}")
private Double embeddingMapperThreshold;
@Value("${s2.parser.linking.value.switch:true}")
private boolean useLinkingValueSwitch;
@Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
private LLMReq.SqlGenType sqlGenType;
@Value("${s2.parser.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${s2.parser.exemplar-recall.number:15}")
private int text2sqlExampleNum;
@Value("${s2.parser.few-shot.number:5}")
private int text2sqlFewShotsNum;
@Value("${s2.parser.self-consistency.number:5}")
private int text2sqlSelfConsistencyNum;
@Value("${s2.parser.show-count:3}")
private Integer parseShowCount;
@Autowired
private SysParameterService sysParameterService;
public Integer getOneDetectionSize() {
return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
}
public Integer getOneDetectionMaxSize() {
return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
}
public Double getMetricDimensionMinThresholdConfig() {
return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
}
public Double getMetricDimensionThresholdConfig() {
return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
}
public Double getDimensionValueMinThresholdConfig() {
return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
}
public Double getDimensionValueThresholdConfig() {
return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
}
public Double getLongTextThreshold() {
return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
}
public Double getShortTextThreshold() {
return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
}
public Integer getQueryTextLengthThreshold() {
return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
/**
 * Resolves the {@code s2.embedding.mapper.word.min} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperWordMin} field supplied as the fallback
 */
public Integer getEmbeddingMapperWordMin() {
    final String key = "s2.embedding.mapper.word.min";
    return convertValue(key, Integer.class, embeddingMapperWordMin);
}
/**
 * Resolves the {@code s2.embedding.mapper.word.max} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperWordMax} field supplied as the fallback
 */
public Integer getEmbeddingMapperWordMax() {
    final String key = "s2.embedding.mapper.word.max";
    return convertValue(key, Integer.class, embeddingMapperWordMax);
}
/**
 * Resolves the {@code s2.embedding.mapper.batch} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperBatch} field supplied as the fallback
 */
public Integer getEmbeddingMapperBatch() {
    final String key = "s2.embedding.mapper.batch";
    return convertValue(key, Integer.class, embeddingMapperBatch);
}
/**
 * Resolves the {@code s2.embedding.mapper.number} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperNumber} field supplied as the fallback
 */
public Integer getEmbeddingMapperNumber() {
    final String key = "s2.embedding.mapper.number";
    return convertValue(key, Integer.class, embeddingMapperNumber);
}
/**
 * Resolves the {@code s2.embedding.mapper.round.number} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperRoundNumber} field supplied as the fallback
 */
public Integer getEmbeddingMapperRoundNumber() {
    final String key = "s2.embedding.mapper.round.number";
    return convertValue(key, Integer.class, embeddingMapperRoundNumber);
}
/**
 * Resolves the {@code s2.embedding.mapper.min.threshold} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperMinThreshold} field supplied as the fallback
 */
public Double getEmbeddingMapperMinThreshold() {
    final String key = "s2.embedding.mapper.min.threshold";
    return convertValue(key, Double.class, embeddingMapperMinThreshold);
}
/**
 * Resolves the {@code s2.embedding.mapper.threshold} setting.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code embeddingMapperThreshold} field supplied as the fallback
 */
public Double getEmbeddingMapperThreshold() {
    final String key = "s2.embedding.mapper.threshold";
    return convertValue(key, Double.class, embeddingMapperThreshold);
}
/**
 * Resolves the {@code s2.parser.use.switch} flag.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code useS2SqlSwitch} field supplied as the fallback
 */
public boolean isUseS2SqlSwitch() {
    final String key = "s2.parser.use.switch";
    return convertValue(key, Boolean.class, useS2SqlSwitch);
}
/**
 * Resolves the {@code s2.parser.linking.value.switch} flag.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code useLinkingValueSwitch} field supplied as the fallback
 */
public boolean isUseLinkingValueSwitch() {
    final String key = "s2.parser.linking.value.switch";
    return convertValue(key, Boolean.class, useLinkingValueSwitch);
}
/**
 * Resolves the {@code s2.parser.strategy} setting as an {@code LLMReq.SqlGenType}.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code sqlGenType} field supplied as the fallback
 */
public LLMReq.SqlGenType getSqlGenType() {
    final String key = "s2.parser.strategy";
    return convertValue(key, LLMReq.SqlGenType.class, sqlGenType);
}
/**
 * Resolves the {@code s2.parse.show-count} setting.
 *
 * <p>NOTE(review): the fallback field {@code parseShowCount} is bound to the property
 * {@code s2.parser.show-count} ("parser"), while this lookup uses
 * {@code s2.parse.show-count} ("parse"). Confirm which key the system parameter store
 * actually uses; the lookup key is left unchanged here.
 *
 * @return the value produced by {@code convertValue}, with the injected
 *         {@code parseShowCount} field supplied as the fallback
 */
public Integer getParseShowCount() {
    final String key = "s2.parse.show-count";
    return convertValue(key, Integer.class, parseShowCount);
}
/**
 * Looks up a system parameter by name and converts its textual value to the requested type.
 *
 * <p>Supported target types are {@code Double}, {@code Integer}, {@code Boolean} and
 * {@code LLMReq.SqlGenType}. The default is returned when the stored value is blank, when
 * the target type is not one of the supported types, or when the lookup or conversion
 * throws.
 *
 * <p>Surrounding whitespace is stripped before conversion: {@code Integer.parseInt},
 * {@code Boolean.parseBoolean} and {@code Enum.valueOf} do not tolerate it, so a stored
 * value such as {@code "5 "} would otherwise be discarded (or, for booleans, read as
 * {@code false}).
 *
 * @param paramName    name of the system parameter to look up
 * @param targetType   class token of the desired result type
 * @param defaultValue value returned when no usable stored value exists
 * @param <T>          result type
 * @return the converted stored value, or {@code defaultValue}
 */
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
    try {
        String rawValue = sysParameterService.getSysParameter().getParameterByName(paramName);
        if (StringUtils.isBlank(rawValue)) {
            return defaultValue;
        }
        String value = rawValue.trim();
        if (targetType == Double.class) {
            return targetType.cast(Double.parseDouble(value));
        }
        if (targetType == Integer.class) {
            return targetType.cast(Integer.parseInt(value));
        }
        if (targetType == Boolean.class) {
            return targetType.cast(Boolean.parseBoolean(value));
        }
        if (targetType == LLMReq.SqlGenType.class) {
            return targetType.cast(LLMReq.SqlGenType.valueOf(value));
        }
    } catch (Exception e) {
        // Best-effort lookup: never let a bad stored value break the caller, but record
        // which parameter failed so the misconfiguration can be located.
        log.error("convertValue failed for parameter: " + paramName, e);
    }
    return defaultValue;
}
}

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.headless.core.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ParameterConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * Declares the parser-related system parameters and exposes the subset returned by
 * {@link #getSysParameters()}.
 *
 * <p>The {@code Parameter} constructor arguments appear to be (name, default value, comment,
 * description, data type, module[, candidate values]) - confirm against {@code Parameter}.
 *
 * <p>Every parameter must have a unique name. The short-text and long-text thresholds
 * previously shared the name {@code s2.parser.text.threshold}, so the two settings could not
 * be stored or resolved independently; they now use distinct {@code .short} / {@code .long}
 * names.
 */
@Service("HeadlessParserConfig")
@Slf4j
public class ParserConfig extends ParameterConfig {

    /** Strategy used to generate S2SQL; default is the one-pass variant. */
    public static final Parameter PARSER_STRATEGY_TYPE =
            new Parameter("s2.parser.strategy", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
                    "LLM解析生成S2SQL策略",
                    "ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
                            + "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql",
                    "list", "Parser相关配置", Lists.newArrayList(
                    "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));

    /** Boolean switch, default {@code "true"}. */
    public static final Parameter PARSER_LINKING_VALUE_ENABLE =
            new Parameter("s2.parser.linking.value.enable", "true",
                    "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
                    "bool", "Parser相关配置");

    /** Text-length threshold, default {@code "10"}. */
    public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
            new Parameter("s2.parser.text.length.threshold", "10",
                    "用户输入文本长短阈值", "文本超过该阈值为长文本",
                    "number", "Parser相关配置");

    /** Match threshold for short text, default {@code "0.5"}. */
    public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
            new Parameter("s2.parser.text.threshold.short", "0.5",
                    "短文本匹配阈值",
                    "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
                            + "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
                    "number", "Parser相关配置");

    /** Match threshold for long text, default {@code "0.8"}. */
    public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
            new Parameter("s2.parser.text.threshold.long", "0.8",
                    "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
                    "number", "Parser相关配置");

    /** Exemplar recall count, default {@code "10"}. */
    public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER =
            new Parameter("s2.parser.exemplar-recall.number", "10",
                    "exemplar召回个数", "",
                    "number", "Parser相关配置");

    /** Few-shot example count, default {@code "5"}. */
    public static final Parameter PARSER_FEW_SHOT_NUMBER =
            new Parameter("s2.parser.few-shot.number", "5",
                    "few-shot样例个数", "样例越多效果可能越好但token消耗越大",
                    "number", "Parser相关配置");

    /** Self-consistency run count, default {@code "1"}. */
    public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
            new Parameter("s2.parser.self-consistency.number", "1",
                    "self-consistency执行个数", "执行越多效果可能越好但token消耗越大",
                    "number", "Parser相关配置");

    /** Number of parse results to show, default {@code "3"}. */
    public static final Parameter PARSER_SHOW_COUNT =
            new Parameter("s2.parser.show.count", "3",
                    "解析结果展示个数", "前端展示的解析个数",
                    "number", "Parser相关配置");

    /** Boolean switch, default {@code "true"}; declared without comment or description. */
    public static final Parameter PARSER_S2SQL_ENABLE =
            new Parameter("s2.parser.s2sql.switch", "true",
                    "", "",
                    "bool", "Parser相关配置");

    /**
     * Returns the parser parameters published by this config.
     *
     * <p>NOTE(review): {@code PARSER_EXEMPLAR_RECALL_NUMBER} and {@code PARSER_S2SQL_ENABLE}
     * are declared above but not returned here - confirm the omission is intentional.
     *
     * @return a new mutable list of the published parameters
     */
    @Override
    public List<Parameter> getSysParameters() {
        return Lists.newArrayList(
                PARSER_STRATEGY_TYPE,
                PARSER_LINKING_VALUE_ENABLE,
                PARSER_TEXT_LENGTH_THRESHOLD,
                PARSER_TEXT_LENGTH_THRESHOLD_SHORT,
                PARSER_TEXT_LENGTH_THRESHOLD_LONG,
                PARSER_FEW_SHOT_NUMBER,
                PARSER_SELF_CONSISTENCY_NUMBER,
                PARSER_SHOW_COUNT
        );
    }
}

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -26,6 +26,8 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SHOW_COUNT;
@Data
@Builder
@NoArgsConstructor
@@ -51,8 +53,8 @@ public class QueryContext {
private LLMConfig llmConfig;
public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Integer parseShowCount = optimizationConfig.getParseShowCount();
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int parseShowCount = Integer.valueOf(parserConfig.getParameterValue(PARSER_SHOW_COUNT));
candidateQueries = candidateQueries.stream()
.sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(),
Comparator.reverseOrder()))