(improvement)(headless&chat)Refactor system parameter impl

This commit is contained in:
jerryjzhang
2024-06-01 01:42:00 +08:00
parent 28960668ce
commit 0f0847824f
32 changed files with 494 additions and 432 deletions

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.auth.api.authentication.utils;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
@@ -20,8 +20,8 @@ public final class UserHolder {
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
User user = REPO.findUser(request, response);
SysParameterService sysParameterService = ContextUtils.getBean(SysParameterService.class);
SysParameter sysParameter = sysParameterService.getSysParameter();
SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class);
SystemConfig sysParameter = sysParameterService.getSysParameter();
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
&& sysParameter.getAdmins().contains(user.getName())) {
user.setIsAdmin(1);

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq;
import com.tencent.supersonic.auth.api.authentication.service.UserService;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import javax.servlet.http.HttpServletRequest;
@@ -18,9 +18,9 @@ import java.util.Set;
@Service
public class UserServiceImpl implements UserService {
private SysParameterService sysParameterService;
private SystemConfigService sysParameterService;
public UserServiceImpl(SysParameterService sysParameterService) {
public UserServiceImpl(SystemConfigService sysParameterService) {
this.sysParameterService = sysParameterService;
}
@@ -28,7 +28,7 @@ public class UserServiceImpl implements UserService {
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
if (user != null) {
SysParameter sysParameter = sysParameterService.getSysParameter();
SystemConfig sysParameter = sysParameterService.getSysParameter();
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
&& sysParameter.getAdmins().contains(user.getName())) {
user.setIsAdmin(1);

View File

@@ -23,7 +23,6 @@ import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
import java.util.Map;
import java.util.HashMap;
@@ -32,6 +31,8 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.Collections;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
@Slf4j
public class MultiTurnParser implements ChatParser {
@@ -51,9 +52,9 @@ public class MultiTurnParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
Environment environment = ContextUtils.getBean(Environment.class);
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = environment.getProperty("s2.parser.multi-turn.enable", Boolean.class);
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = agentMultiTurnConfig != null
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;

View File

@@ -17,8 +17,8 @@ public class NL2PluginParser implements ChatParser {
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer.recognize(chatParseContext, parseResp);
log.info("{} context:{} result:{}", pluginRecognizer.getClass().getSimpleName(),
JsonUtil.toString(chatParseContext), JsonUtil.toString(parseResp));
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
JsonUtil.toString(parseResp));
});
}

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.chat.server.parser;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ParameterConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("ChatParserConfig")
@Slf4j
public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_MULTI_TURN_ENABLE =
new Parameter("s2.parser.multi-turn.enable", "false",
"是否开启多轮对话", "开启多轮对话将消耗更多token",
"bool", "Parser相关配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
PARSER_MULTI_TURN_ENABLE
);
}
}

View File

@@ -10,15 +10,17 @@ import java.util.List;
@NoArgsConstructor
public class Parameter {
private String name;
private String value;
private String defaultValue;
private String comment;
private String description;
private String dataType;
private String module;
private List<Object> candidateValues;
public Parameter(String name, String value, String comment, String description, String dataType, String module) {
public Parameter(String name, String defaultValue, String comment,
String description, String dataType, String module) {
this.name = name;
this.value = value;
this.defaultValue = defaultValue;
this.comment = comment;
this.description = description;
this.dataType = dataType;

View File

@@ -0,0 +1,49 @@
package com.tencent.supersonic.common.pojo;
import com.tencent.supersonic.common.service.SystemConfigService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@Slf4j
public abstract class ParameterConfig {
@Autowired
private SystemConfigService sysConfigService;
@Autowired
private Environment environment;
protected abstract List<Parameter> getSysParameters();
/**
* Parameter value will be derived by following orders:
* 1. `system config` set with user interface
* 2. `system property` set with application.yaml
* 3. `default value` set with parameter declaration
* @param parameter
* @return parameter value
*/
public String getParameterValue(Parameter parameter) {
String paramName = parameter.getName();
String value = sysConfigService.getSysParameter().getParameterByName(paramName);
try {
if (StringUtils.isBlank(value)) {
if (environment.containsProperty(paramName)) {
value = environment.getProperty(paramName);
} else {
value = parameter.getDefaultValue();
}
}
} catch (Exception e) {
log.error("Failed to get parameter value for {} with exception: {}", paramName, e);
}
return value;
}
}

View File

@@ -1,111 +0,0 @@
package com.tencent.supersonic.common.pojo;
import com.google.common.collect.Lists;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Data
public class SysParameter {
private Integer id;
private List<String> admins;
private List<Parameter> parameters;
public String getAdmin() {
if (CollectionUtils.isEmpty(admins)) {
return "";
}
return StringUtils.join(admins, ",");
}
public String getParameterByName(String name) {
if (StringUtils.isBlank(name)) {
return "";
}
Map<String, String> nameToValue = parameters.stream()
.collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (k1, k2) -> k1));
return nameToValue.get(name);
}
public void setAdminList(String admin) {
if (StringUtils.isNotBlank(admin)) {
admins = Arrays.asList(admin.split(","));
} else {
admins = Lists.newArrayList();
}
}
public void init() {
parameters = Lists.newArrayList();
admins = Lists.newArrayList("admin");
//detect config
parameters.add(new Parameter("s2.one.detection.size", "8",
"一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
"number", "Mapper相关配置"));
parameters.add(new Parameter("s2.one.detection.max.size", "20",
"一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置"));
//mapper config
parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3",
"指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置"));
parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25",
"指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置"));
parameters.add(new Parameter("s2.dimension.value.threshold", "0.5",
"维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置"));
parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3",
"维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置"));
//embedding mapper config
parameters.add(new Parameter("s2.embedding.mapper.word.min",
"4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
parameters.add(new Parameter("s2.embedding.mapper.word.max", "5",
"用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
parameters.add(new Parameter("s2.embedding.mapper.batch", "50",
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("s2.embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("s2.embedding.mapper.threshold",
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
parameters.add(new Parameter("s2.embedding.mapper.min.threshold",
"0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"));
//parser config
Parameter s2SQLParameter = new Parameter("s2.parser.strategy",
"TWO_PASS_AUTO_COT_SELF_CONSISTENCY",
"LLM解析生成S2SQL策略",
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置");
s2SQLParameter.setCandidateValues(Lists.newArrayList(
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
parameters.add(s2SQLParameter);
parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true",
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
"bool", "Parser相关配置"));
parameters.add(new Parameter("s2.query.text.length.threshold", "10",
"用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置"));
parameters.add(new Parameter("s2.short.text.threshold", "0.5",
"短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置"));
parameters.add(new Parameter("s2.long.text.threshold", "0.8",
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置"));
parameters.add(new Parameter("s2.parse.show-count", "3",
"解析结果个数", "前端展示的解析个数",
"number", "Parser相关配置"));
}
}

View File

@@ -0,0 +1,59 @@
package com.tencent.supersonic.common.pojo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Data
public class SystemConfig {
private Integer id;
private List<String> admins;
private List<Parameter> parameters;
public String getAdmin() {
if (CollectionUtils.isEmpty(admins)) {
return "";
}
return StringUtils.join(admins, ",");
}
public String getParameterByName(String name) {
if (StringUtils.isBlank(name)) {
return "";
}
Map<String, String> nameToValue = parameters.stream()
.collect(Collectors.toMap(Parameter::getName, Parameter::getDefaultValue, (k1, k2) -> k1));
return nameToValue.get(name);
}
public void setAdminList(String admin) {
if (StringUtils.isNotBlank(admin)) {
admins = Arrays.asList(admin.split(","));
} else {
admins = Lists.newArrayList();
}
}
public void init() {
parameters = Lists.newArrayList();
admins = Lists.newArrayList("admin");
Collection<ParameterConfig> configurableParameters =
ContextUtils.getBeansOfType(ParameterConfig.class).values();
for (ParameterConfig configParameters : configurableParameters) {
parameters.addAll(configParameters.getSysParameters());
}
}
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.common.rest;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
@@ -11,20 +11,20 @@ import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping({"/api/semantic/parameter"})
public class SysParameterController {
public class SystemConfigController {
@Autowired
private SysParameterService sysParameterService;
private SystemConfigService sysConfigService;
@PostMapping
public Boolean save(@RequestBody SysParameter sysParameter) {
sysParameterService.save(sysParameter);
public Boolean save(@RequestBody SystemConfig sysParameter) {
sysConfigService.save(sysParameter);
return true;
}
@GetMapping
public SysParameter get() {
return sysParameterService.getSysParameter();
public SystemConfig get() {
return sysConfigService.getSysParameter();
}
}

View File

@@ -2,12 +2,12 @@ package com.tencent.supersonic.common.service;
import com.baomidou.mybatisplus.extension.service.IService;
import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.SystemConfig;
public interface SysParameterService extends IService<SysParameterDO> {
public interface SystemConfigService extends IService<SysParameterDO> {
SysParameter getSysParameter();
SystemConfig getSysParameter();
void save(SysParameter sysParameter);
void save(SystemConfig sysConfig);
}

View File

@@ -6,22 +6,22 @@ import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO;
import com.tencent.supersonic.common.persistence.mapper.SysParameterMapper;
import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.List;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service
public class SysParameterServiceImpl
extends ServiceImpl<SysParameterMapper, SysParameterDO> implements SysParameterService {
public class SystemConfigServiceImpl
extends ServiceImpl<SysParameterMapper, SysParameterDO> implements SystemConfigService {
@Override
public SysParameter getSysParameter() {
public SystemConfig getSysParameter() {
List<SysParameterDO> list = list();
if (CollectionUtils.isEmpty(list)) {
SysParameter sysParameter = new SysParameter();
SystemConfig sysParameter = new SystemConfig();
sysParameter.setId(1);
sysParameter.init();
save(sysParameter);
@@ -31,13 +31,13 @@ public class SysParameterServiceImpl
}
@Override
public void save(SysParameter sysParameter) {
SysParameterDO sysParameterDO = convert(sysParameter);
public void save(SystemConfig sysConfig) {
SysParameterDO sysParameterDO = convert(sysConfig);
saveOrUpdate(sysParameterDO);
}
private SysParameter convert(SysParameterDO sysParameterDO) {
SysParameter sysParameter = new SysParameter();
private SystemConfig convert(SysParameterDO sysParameterDO) {
SystemConfig sysParameter = new SystemConfig();
sysParameter.setId(sysParameterDO.getId());
List<Parameter> parameters = JsonUtil.toObject(sysParameterDO.getParameters(),
new TypeReference<List<Parameter>>() {
@@ -47,7 +47,7 @@ public class SysParameterServiceImpl
return sysParameter;
}
private SysParameterDO convert(SysParameter sysParameter) {
private SysParameterDO convert(SystemConfig sysParameter) {
SysParameterDO sysParameterDO = new SysParameterDO();
sysParameterDO.setId(sysParameter.getId());
sysParameterDO.setParameters(JSONObject.toJSONString(sysParameter.getParameters()));

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.MapperConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
@@ -27,7 +28,10 @@ import java.util.stream.Collectors;
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Autowired
private MapperHelper mapperHelper;
protected MapperHelper mapperHelper;
@Autowired
protected MapperConfig mapperConfig;
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,

View File

@@ -5,12 +5,10 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -22,6 +20,9 @@ import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
/**
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
* It currently supports fuzzy matching against names and aliases.
@@ -30,10 +31,6 @@ import java.util.stream.Collectors;
@Slf4j
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MapperHelper mapperHelper;
private List<SchemaElement> allElements;
@Override
@@ -94,9 +91,8 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
}
private Double getThreshold(QueryContext queryContext) {
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
@@ -22,6 +21,14 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_BATCH;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MAX;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MIN;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_NUMBER;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
/**
* EmbeddingMatchStrategy uses vector database to perform
* similarity search against the embeddings of schema elements.
@@ -30,9 +37,6 @@ import java.util.stream.Collectors;
@Slf4j
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private MetaEmbeddingService metaEmbeddingService;
@@ -48,24 +52,27 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
Set<Long> detectDataSetIds, String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MIN));
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MAX));
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_BATCH));
List<String> queryTextsList = detectSegments.stream()
.map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
&& detectSegment.length() >= embedddingMapperMin
&& detectSegment.length() <= embedddingMapperMax)
.collect(Collectors.toList());
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
optimizationConfig.getEmbeddingMapperBatch());
embeddingMapperBatch);
for (List<String> queryTextsSub : queryTextsSubList) {
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
@@ -74,15 +81,16 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
List<String> queryTextsSub, QueryContext queryContext) {
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(),
optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum());
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
// step2. retrieveQuery by detectSegment
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
@@ -118,7 +126,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
.collect(Collectors.toList());
// step4. select mapResul in one round
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
List<EmbeddingResult> oneRoundResults = collect.stream()
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
.limit(roundNumber)

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -21,6 +20,14 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_MAX_SIZE;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_SIZE;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DIMENSION_VALUE_SIZE;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD_MIN;
/**
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
* match schema elements. It currently supports prefix and suffix matching
@@ -30,12 +37,6 @@ import java.util.stream.Collectors;
@Slf4j
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Autowired
private MapperHelper mapperHelper;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@@ -65,7 +66,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
@@ -99,12 +100,13 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only M dimensionValue or N-M metric/dimension value per rond.
int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.limit(optimizationConfig.getOneDetectionDimensionValueSize())
.limit(oneDetectionValueSize)
.collect(Collectors.toList());
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
List<HanlpMapResult> oneRoundResults = new ArrayList<>();
// add the dimensionValue if it exists
@@ -129,13 +131,14 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
if (mapperHelper.existDimensionValues(natures)) {
threshold = optimizationConfig.getDimensionValueThresholdConfig();
minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig();
threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD));
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD_MIN));
}
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
}
}

View File

@@ -2,11 +2,9 @@ package com.tencent.supersonic.headless.core.chat.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Comparator;
@@ -20,9 +18,6 @@ import java.util.stream.Collectors;
@Slf4j
public class MapperHelper {
@Autowired
private OptimizationConfig optimizationConfig;
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
Integer subRegLength = regOffsetToLength.get(index);
if (Objects.nonNull(subRegLength)) {

View File

@@ -2,12 +2,16 @@ package com.tencent.supersonic.headless.core.chat.parser;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_LONG;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_SHORT;
/**
* This checker can be used by semantic parsers to check if query intent
* has already been satisfied by current candidate queries. If so, current
@@ -32,12 +36,19 @@ public class SatisfactionChecker {
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.replaceAll(" ", "").length();
double degree = semanticParseInfo.getScore() / queryTextLength;
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
if (degree < optimizationConfig.getLongTextThreshold()) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int textLengthThreshold =
Integer.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD));
double longTextLengthThreshold =
Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_LONG));
double shortTextLengthThreshold =
Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_SHORT));
if (queryTextLength > textLengthThreshold) {
if (degree < longTextLengthThreshold) {
return false;
}
} else if (degree < optimizationConfig.getShortTextThreshold()) {
} else if (degree < shortTextLengthThreshold) {
return false;
}
log.info("queryMode:{}, degree:{}, parse info:{}",

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@@ -16,8 +15,7 @@ public class JavaLLMProxy implements LLMProxy {
public LLMResp text2sql(LLMReq llmReq) {
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(
SqlGenType.getMode(llmReq.getSqlGenerationMode()));
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
String modelName = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGenStrategy.generate(llmReq);
result.setQuery(llmReq.getQueryText());

View File

@@ -11,8 +11,8 @@ import com.tencent.supersonic.headless.core.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.config.LLMParserConfig;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
@@ -31,14 +31,18 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_STRATEGY_TYPE;
@Slf4j
@Service
public class LLMRequestService {
@Autowired
private LLMParserConfig llmParserConfig;
@Autowired
private OptimizationConfig optimizationConfig;
private ParserConfig parserConfig;
public boolean isSkip(QueryContext queryCtx) {
if (!queryCtx.getText2SQLType().enableLLM()) {
@@ -86,7 +90,9 @@ public class LLMRequestService {
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
if (optimizationConfig.isUseLinkingValueSwitch()) {
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
if (linkingValueEnabled) {
linking.addAll(linkingValues);
}
llmReq.setLinking(linking);
@@ -96,7 +102,7 @@ public class LLMRequestService {
currentDate = DateUtils.getBeforeDate(0);
}
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName());
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setLlmConfig(queryCtx.getLlmConfig());
return llmReq;
}

View File

@@ -18,6 +18,11 @@ import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
@Service
@Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@@ -27,11 +32,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
fewShotNumber, selfConsistencyNumber);
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.slf4j.Logger;
@@ -25,7 +25,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
protected ExemplarManager exemplarManager;
@Autowired
protected OptimizationConfig optimizationConfig;
protected ParserConfig parserConfig;
@Autowired
protected PromptGenerator promptGenerator;

View File

@@ -17,6 +17,10 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
@Service
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
@@ -24,11 +28,15 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
exemplarRecallNumber);
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
fewShotNumber, selfConsistencyNumber);
//2.generator linking prompt,and parallel generate response.
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
@@ -21,6 +21,8 @@ import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_S2SQL_ENABLE;
@Slf4j
@ToString
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
@@ -73,8 +75,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
}
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
if (!optimizationConfig.isUseS2SqlSwitch()) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE));
if (!s2sqlEnable) {
return;
}
QueryStructReq queryStructReq = convertQueryStruct();

View File

@@ -22,7 +22,7 @@ public class LLMReq {
private String priorExts;
private String sqlGenerationMode;
private SqlGenType sqlGenType;
private LLMConfig llmConfig;
@@ -82,14 +82,5 @@ public class LLMReq {
return name;
}
public static SqlGenType getMode(String name) {
for (SqlGenType sqlGenType : SqlGenType.values()) {
if (sqlGenType.name.equals(name)) {
return sqlGenType;
}
}
return null;
}
}
}

View File

@@ -0,0 +1,110 @@
package com.tencent.supersonic.headless.core.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ParameterConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("HeadlessMapperConfig")
public class MapperConfig extends ParameterConfig {
public static final Parameter MAPPER_DETECTION_SIZE =
new Parameter("s2.mapper.detection.size", "8",
"一次探测返回结果个数",
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
"number", "Mapper相关配置");
public static final Parameter MAPPER_DETECTION_MAX_SIZE =
new Parameter("s2.mapper.detection.max.size", "20",
"一次探测前后缀匹配结果返回个数",
"单次前后缀匹配返回的结果个数",
"number", "Mapper相关配置");
public static final Parameter MAPPER_NAME_THRESHOLD =
new Parameter("s2.mapper.name.threshold", "0.3",
"指标名、维度名文本相似度阈值",
"文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置");
public static final Parameter MAPPER_NAME_THRESHOLD_MIN =
new Parameter("s2.mapper.name.min.threshold", "0.25",
"指标名、维度名最小文本相似度阈值",
"指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter MAPPER_DIMENSION_VALUE_SIZE =
new Parameter("s2.mapper.value.size", "1",
"指标名、维度名最小文本相似度阈值",
"指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter MAPPER_VALUE_THRESHOLD =
new Parameter("s2.mapper.value.threshold", "0.5",
"维度值文本相似度阈值",
"文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置");
public static final Parameter MAPPER_VALUE_THRESHOLD_MIN =
new Parameter("s2.mapper.value.min.threshold", "0.3",
"维度值最小文本相似度阈值",
"维度值相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_MIN =
new Parameter("s2.mapper.embedding.word.min", "4",
"用于向量召回最小的文本长度",
"为提高向量召回效率, 小于该长度的文本不进行向量语义召回",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_MAX =
new Parameter("s2.mapper.embedding.word.max", "5",
"用于向量召回最大的文本长度",
"为提高向量召回效率, 大于该长度的文本不进行向量语义召回",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_BATCH =
new Parameter("s2.mapper.embedding.batch", "50",
"批量向量召回文本请求个数",
"每次进行向量语义召回的原始文本片段个数",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_NUMBER =
new Parameter("s2.mapper.embedding.number", "5",
"批量向量召回文本返回结果个数",
"每个文本进行向量语义召回的文本结果个数",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.99",
"向量召回相似度阈值",
"相似度小于该阈值的则舍弃",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
new Parameter("s2.mapper.embedding.min.threshold", "0.9",
"向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
new Parameter("s2.mapper.embedding.round.number", "10",
"向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
MAPPER_DETECTION_SIZE,
MAPPER_DETECTION_MAX_SIZE,
MAPPER_NAME_THRESHOLD,
MAPPER_NAME_THRESHOLD_MIN,
MAPPER_DIMENSION_VALUE_SIZE,
MAPPER_VALUE_THRESHOLD,
MAPPER_VALUE_THRESHOLD_MIN
);
}
}

View File

@@ -1,193 +0,0 @@
package com.tencent.supersonic.headless.core.config;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
@Slf4j
public class OptimizationConfig {
@Value("${s2.one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${s2.one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@Value("${s2.one.detection.dimensionValue.size:1}")
private Integer oneDetectionDimensionValueSize;
@Value("${s2.metric.dimension.min.threshold:0.3}")
private Double metricDimensionMinThresholdConfig;
@Value("${s2.metric.dimension.threshold:0.3}")
private Double metricDimensionThresholdConfig;
@Value("${s2.dimension.value.min.threshold:0.2}")
private Double dimensionValueMinThresholdConfig;
@Value("${s2.dimension.value.threshold:0.5}")
private Double dimensionValueThresholdConfig;
@Value("${s2.long.text.threshold:0.8}")
private Double longTextThreshold;
@Value("${s2.short.text.threshold:0.5}")
private Double shortTextThreshold;
@Value("${s2.query.text.length.threshold:10}")
private Integer queryTextLengthThreshold;
@Value("${s2.embedding.mapper.word.min:4}")
private int embeddingMapperWordMin;
@Value("${s2.embedding.mapper.word.max:4}")
private int embeddingMapperWordMax;
@Value("${s2.embedding.mapper.batch:50}")
private int embeddingMapperBatch;
@Value("${s2.embedding.mapper.number:5}")
private int embeddingMapperNumber;
@Value("${s2.embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${s2.embedding.mapper.min.threshold:0.6}")
private Double embeddingMapperMinThreshold;
@Value("${s2.embedding.mapper.threshold:0.99}")
private Double embeddingMapperThreshold;
@Value("${s2.parser.linking.value.switch:true}")
private boolean useLinkingValueSwitch;
@Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
private LLMReq.SqlGenType sqlGenType;
@Value("${s2.parser.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${s2.parser.exemplar-recall.number:15}")
private int text2sqlExampleNum;
@Value("${s2.parser.few-shot.number:5}")
private int text2sqlFewShotsNum;
@Value("${s2.parser.self-consistency.number:5}")
private int text2sqlSelfConsistencyNum;
@Value("${s2.parser.show-count:3}")
private Integer parseShowCount;
@Autowired
private SysParameterService sysParameterService;
public Integer getOneDetectionSize() {
return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
}
public Integer getOneDetectionMaxSize() {
return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
}
public Double getMetricDimensionMinThresholdConfig() {
return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
}
public Double getMetricDimensionThresholdConfig() {
return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
}
public Double getDimensionValueMinThresholdConfig() {
return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
}
public Double getDimensionValueThresholdConfig() {
return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
}
public Double getLongTextThreshold() {
return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
}
public Double getShortTextThreshold() {
return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
}
public Integer getQueryTextLengthThreshold() {
return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public Integer getEmbeddingMapperWordMin() {
return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
}
public Integer getEmbeddingMapperWordMax() {
return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
}
public Integer getEmbeddingMapperBatch() {
return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
}
public Integer getEmbeddingMapperNumber() {
return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
}
public Integer getEmbeddingMapperRoundNumber() {
return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
}
public Double getEmbeddingMapperMinThreshold() {
return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
}
public Double getEmbeddingMapperThreshold() {
return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
}
public boolean isUseLinkingValueSwitch() {
return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
}
public LLMReq.SqlGenType getSqlGenType() {
return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
}
public Integer getParseShowCount() {
return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
}
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
try {
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
if (StringUtils.isBlank(value)) {
return defaultValue;
}
if (targetType == Double.class) {
return targetType.cast(Double.parseDouble(value));
} else if (targetType == Integer.class) {
return targetType.cast(Integer.parseInt(value));
} else if (targetType == Boolean.class) {
return targetType.cast(Boolean.parseBoolean(value));
} else if (targetType == LLMReq.SqlGenType.class) {
return targetType.cast(LLMReq.SqlGenType.valueOf(value));
}
} catch (Exception e) {
log.error("convertValue", e);
}
return defaultValue;
}
}

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.headless.core.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ParameterConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Service("HeadlessParserConfig")
@Slf4j
public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_STRATEGY_TYPE =
new Parameter("s2.parser.strategy", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
"LLM解析生成S2SQL策略",
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql",
"list", "Parser相关配置", Lists.newArrayList(
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
new Parameter("s2.parser.linking.value.enable", "true",
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
"bool", "Parser相关配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
new Parameter("s2.parser.text.length.threshold", "10",
"用户输入文本长短阈值", "文本超过该阈值为长文本",
"number", "Parser相关配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
new Parameter("s2.parser.text.threshold", "0.5",
"短文本匹配阈值",
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
new Parameter("s2.parser.text.threshold", "0.8",
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置");
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER =
new Parameter("s2.parser.exemplar-recall.number", "10",
"exemplar召回个数", "",
"number", "Parser相关配置");
public static final Parameter PARSER_FEW_SHOT_NUMBER =
new Parameter("s2.parser.few-shot.number", "5",
"few-shot样例个数", "样例越多效果可能越好但token消耗越大",
"number", "Parser相关配置");
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
new Parameter("s2.parser.self-consistency.number", "1",
"self-consistency执行个数", "执行越多效果可能越好但token消耗越大",
"number", "Parser相关配置");
public static final Parameter PARSER_SHOW_COUNT =
new Parameter("s2.parser.show.count", "3",
"解析结果展示个数", "前端展示的解析个数",
"number", "Parser相关配置");
public static final Parameter PARSER_S2SQL_ENABLE =
new Parameter("s2.parser.s2sql.switch", "true",
"", "",
"bool", "Parser相关配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(
PARSER_STRATEGY_TYPE,
PARSER_LINKING_VALUE_ENABLE,
PARSER_TEXT_LENGTH_THRESHOLD,
PARSER_TEXT_LENGTH_THRESHOLD_SHORT,
PARSER_TEXT_LENGTH_THRESHOLD_LONG,
PARSER_FEW_SHOT_NUMBER,
PARSER_SELF_CONSISTENCY_NUMBER,
PARSER_SHOW_COUNT
);
}
}

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.config.ParserConfig;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -26,6 +26,8 @@ import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SHOW_COUNT;
@Data
@Builder
@NoArgsConstructor
@@ -51,8 +53,8 @@ public class QueryContext {
private LLMConfig llmConfig;
public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Integer parseShowCount = optimizationConfig.getParseShowCount();
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int parseShowCount = Integer.valueOf(parserConfig.getParameterValue(PARSER_SHOW_COUNT));
candidateQueries = candidateQueries.stream()
.sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(),
Comparator.reverseOrder()))

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
@@ -84,7 +84,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
@Autowired
protected AgentService agentService;
@Autowired
protected SysParameterService sysParameterService;
protected SystemConfigService sysParameterService;
@Autowired
protected CanvasService canvasService;
@Autowired

View File

@@ -16,7 +16,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.SystemConfig;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
@@ -150,7 +150,7 @@ public class S2VisitsDemo extends S2BaseDemo {
}
public void addSysParameter() {
SysParameter sysParameter = new SysParameter();
SystemConfig sysParameter = new SystemConfig();
sysParameter.setId(1);
sysParameter.init();
sysParameterService.save(sysParameter);

View File

@@ -37,7 +37,7 @@ logging:
s2:
parser:
strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
strategy: ONE_PASS_AUTO_COT_SELF_CONSISTENCY
exemplar-recall:
number: 5
few-shot: