mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(headless&chat)Refactor system parameter impl
This commit is contained in:
@@ -2,8 +2,8 @@ package com.tencent.supersonic.auth.api.authentication.utils;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserStrategy;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -20,8 +20,8 @@ public final class UserHolder {
|
||||
|
||||
public static User findUser(HttpServletRequest request, HttpServletResponse response) {
|
||||
User user = REPO.findUser(request, response);
|
||||
SysParameterService sysParameterService = ContextUtils.getBean(SysParameterService.class);
|
||||
SysParameter sysParameter = sysParameterService.getSysParameter();
|
||||
SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class);
|
||||
SystemConfig sysParameter = sysParameterService.getSysParameter();
|
||||
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
|
||||
&& sysParameter.getAdmins().contains(user.getName())) {
|
||||
user.setIsAdmin(1);
|
||||
|
||||
@@ -6,8 +6,8 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.auth.authentication.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
@@ -18,9 +18,9 @@ import java.util.Set;
|
||||
@Service
|
||||
public class UserServiceImpl implements UserService {
|
||||
|
||||
private SysParameterService sysParameterService;
|
||||
private SystemConfigService sysParameterService;
|
||||
|
||||
public UserServiceImpl(SysParameterService sysParameterService) {
|
||||
public UserServiceImpl(SystemConfigService sysParameterService) {
|
||||
this.sysParameterService = sysParameterService;
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ public class UserServiceImpl implements UserService {
|
||||
public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
if (user != null) {
|
||||
SysParameter sysParameter = sysParameterService.getSysParameter();
|
||||
SystemConfig sysParameter = sysParameterService.getSysParameter();
|
||||
if (!CollectionUtils.isEmpty(sysParameter.getAdmins())
|
||||
&& sysParameter.getAdmins().contains(user.getName())) {
|
||||
user.setIsAdmin(1);
|
||||
|
||||
@@ -23,7 +23,6 @@ import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.core.env.Environment;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.HashMap;
|
||||
@@ -32,6 +31,8 @@ import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.Collections;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
public class MultiTurnParser implements ChatParser {
|
||||
|
||||
@@ -51,9 +52,9 @@ public class MultiTurnParser implements ChatParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = environment.getProperty("s2.parser.multi-turn.enable", Boolean.class);
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
|
||||
@@ -17,8 +17,8 @@ public class NL2PluginParser implements ChatParser {
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||
log.info("{} context:{} result:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(chatParseContext), JsonUtil.toString(parseResp));
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("ChatParserConfig")
|
||||
@Slf4j
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_MULTI_TURN_ENABLE =
|
||||
new Parameter("s2.parser.multi-turn.enable", "false",
|
||||
"是否开启多轮对话", "开启多轮对话将消耗更多token",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
PARSER_MULTI_TURN_ENABLE
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,15 +10,17 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class Parameter {
|
||||
private String name;
|
||||
private String value;
|
||||
private String defaultValue;
|
||||
private String comment;
|
||||
private String description;
|
||||
private String dataType;
|
||||
private String module;
|
||||
private List<Object> candidateValues;
|
||||
public Parameter(String name, String value, String comment, String description, String dataType, String module) {
|
||||
|
||||
public Parameter(String name, String defaultValue, String comment,
|
||||
String description, String dataType, String module) {
|
||||
this.name = name;
|
||||
this.value = value;
|
||||
this.defaultValue = defaultValue;
|
||||
this.comment = comment;
|
||||
this.description = description;
|
||||
this.dataType = dataType;
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.env.Environment;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public abstract class ParameterConfig {
|
||||
|
||||
@Autowired
|
||||
private SystemConfigService sysConfigService;
|
||||
|
||||
@Autowired
|
||||
private Environment environment;
|
||||
|
||||
protected abstract List<Parameter> getSysParameters();
|
||||
|
||||
/**
|
||||
* Parameter value will be derived by following orders:
|
||||
* 1. `system config` set with user interface
|
||||
* 2. `system property` set with application.yaml
|
||||
* 3. `default value` set with parameter declaration
|
||||
* @param parameter
|
||||
* @return parameter value
|
||||
*/
|
||||
public String getParameterValue(Parameter parameter) {
|
||||
String paramName = parameter.getName();
|
||||
String value = sysConfigService.getSysParameter().getParameterByName(paramName);
|
||||
try {
|
||||
if (StringUtils.isBlank(value)) {
|
||||
if (environment.containsProperty(paramName)) {
|
||||
value = environment.getProperty(paramName);
|
||||
} else {
|
||||
value = parameter.getDefaultValue();
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to get parameter value for {} with exception: {}", paramName, e);
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class SysParameter {
|
||||
|
||||
private Integer id;
|
||||
|
||||
private List<String> admins;
|
||||
|
||||
private List<Parameter> parameters;
|
||||
|
||||
public String getAdmin() {
|
||||
if (CollectionUtils.isEmpty(admins)) {
|
||||
return "";
|
||||
}
|
||||
return StringUtils.join(admins, ",");
|
||||
}
|
||||
|
||||
public String getParameterByName(String name) {
|
||||
if (StringUtils.isBlank(name)) {
|
||||
return "";
|
||||
}
|
||||
Map<String, String> nameToValue = parameters.stream()
|
||||
.collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (k1, k2) -> k1));
|
||||
return nameToValue.get(name);
|
||||
}
|
||||
|
||||
public void setAdminList(String admin) {
|
||||
if (StringUtils.isNotBlank(admin)) {
|
||||
admins = Arrays.asList(admin.split(","));
|
||||
} else {
|
||||
admins = Lists.newArrayList();
|
||||
}
|
||||
}
|
||||
|
||||
public void init() {
|
||||
parameters = Lists.newArrayList();
|
||||
admins = Lists.newArrayList("admin");
|
||||
|
||||
//detect config
|
||||
parameters.add(new Parameter("s2.one.detection.size", "8",
|
||||
"一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.one.detection.max.size", "20",
|
||||
"一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置"));
|
||||
|
||||
//mapper config
|
||||
parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3",
|
||||
"指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25",
|
||||
"指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.dimension.value.threshold", "0.5",
|
||||
"维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3",
|
||||
"维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置"));
|
||||
|
||||
//embedding mapper config
|
||||
parameters.add(new Parameter("s2.embedding.mapper.word.min",
|
||||
"4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.embedding.mapper.word.max", "5",
|
||||
"用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.embedding.mapper.batch", "50",
|
||||
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.embedding.mapper.number", "5",
|
||||
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.embedding.mapper.threshold",
|
||||
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
|
||||
parameters.add(new Parameter("s2.embedding.mapper.min.threshold",
|
||||
"0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"));
|
||||
|
||||
//parser config
|
||||
Parameter s2SQLParameter = new Parameter("s2.parser.strategy",
|
||||
"TWO_PASS_AUTO_COT_SELF_CONSISTENCY",
|
||||
"LLM解析生成S2SQL策略",
|
||||
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
|
||||
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置");
|
||||
|
||||
s2SQLParameter.setCandidateValues(Lists.newArrayList(
|
||||
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
|
||||
parameters.add(s2SQLParameter);
|
||||
parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true",
|
||||
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
|
||||
"bool", "Parser相关配置"));
|
||||
parameters.add(new Parameter("s2.query.text.length.threshold", "10",
|
||||
"用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置"));
|
||||
parameters.add(new Parameter("s2.short.text.threshold", "0.5",
|
||||
"短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置"));
|
||||
parameters.add(new Parameter("s2.long.text.threshold", "0.8",
|
||||
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置"));
|
||||
parameters.add(new Parameter("s2.parse.show-count", "3",
|
||||
"解析结果个数", "前端展示的解析个数",
|
||||
"number", "Parser相关配置"));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class SystemConfig {
|
||||
|
||||
private Integer id;
|
||||
|
||||
private List<String> admins;
|
||||
|
||||
private List<Parameter> parameters;
|
||||
|
||||
public String getAdmin() {
|
||||
if (CollectionUtils.isEmpty(admins)) {
|
||||
return "";
|
||||
}
|
||||
return StringUtils.join(admins, ",");
|
||||
}
|
||||
|
||||
public String getParameterByName(String name) {
|
||||
if (StringUtils.isBlank(name)) {
|
||||
return "";
|
||||
}
|
||||
Map<String, String> nameToValue = parameters.stream()
|
||||
.collect(Collectors.toMap(Parameter::getName, Parameter::getDefaultValue, (k1, k2) -> k1));
|
||||
return nameToValue.get(name);
|
||||
}
|
||||
|
||||
public void setAdminList(String admin) {
|
||||
if (StringUtils.isNotBlank(admin)) {
|
||||
admins = Arrays.asList(admin.split(","));
|
||||
} else {
|
||||
admins = Lists.newArrayList();
|
||||
}
|
||||
}
|
||||
|
||||
public void init() {
|
||||
parameters = Lists.newArrayList();
|
||||
admins = Lists.newArrayList("admin");
|
||||
|
||||
Collection<ParameterConfig> configurableParameters =
|
||||
ContextUtils.getBeansOfType(ParameterConfig.class).values();
|
||||
for (ParameterConfig configParameters : configurableParameters) {
|
||||
parameters.addAll(configParameters.getSysParameters());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.common.rest;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
@@ -11,20 +11,20 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@RestController
|
||||
@RequestMapping({"/api/semantic/parameter"})
|
||||
public class SysParameterController {
|
||||
public class SystemConfigController {
|
||||
|
||||
@Autowired
|
||||
private SysParameterService sysParameterService;
|
||||
private SystemConfigService sysConfigService;
|
||||
|
||||
@PostMapping
|
||||
public Boolean save(@RequestBody SysParameter sysParameter) {
|
||||
sysParameterService.save(sysParameter);
|
||||
public Boolean save(@RequestBody SystemConfig sysParameter) {
|
||||
sysConfigService.save(sysParameter);
|
||||
return true;
|
||||
}
|
||||
|
||||
@GetMapping
|
||||
public SysParameter get() {
|
||||
return sysParameterService.getSysParameter();
|
||||
public SystemConfig get() {
|
||||
return sysConfigService.getSysParameter();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -2,12 +2,12 @@ package com.tencent.supersonic.common.service;
|
||||
|
||||
import com.baomidou.mybatisplus.extension.service.IService;
|
||||
import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
|
||||
public interface SysParameterService extends IService<SysParameterDO> {
|
||||
public interface SystemConfigService extends IService<SysParameterDO> {
|
||||
|
||||
SysParameter getSysParameter();
|
||||
SystemConfig getSysParameter();
|
||||
|
||||
void save(SysParameter sysParameter);
|
||||
void save(SystemConfig sysConfig);
|
||||
|
||||
}
|
||||
@@ -6,22 +6,22 @@ import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO;
|
||||
import com.tencent.supersonic.common.persistence.mapper.SysParameterMapper;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.List;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Service
|
||||
public class SysParameterServiceImpl
|
||||
extends ServiceImpl<SysParameterMapper, SysParameterDO> implements SysParameterService {
|
||||
public class SystemConfigServiceImpl
|
||||
extends ServiceImpl<SysParameterMapper, SysParameterDO> implements SystemConfigService {
|
||||
|
||||
@Override
|
||||
public SysParameter getSysParameter() {
|
||||
public SystemConfig getSysParameter() {
|
||||
List<SysParameterDO> list = list();
|
||||
if (CollectionUtils.isEmpty(list)) {
|
||||
SysParameter sysParameter = new SysParameter();
|
||||
SystemConfig sysParameter = new SystemConfig();
|
||||
sysParameter.setId(1);
|
||||
sysParameter.init();
|
||||
save(sysParameter);
|
||||
@@ -31,13 +31,13 @@ public class SysParameterServiceImpl
|
||||
}
|
||||
|
||||
@Override
|
||||
public void save(SysParameter sysParameter) {
|
||||
SysParameterDO sysParameterDO = convert(sysParameter);
|
||||
public void save(SystemConfig sysConfig) {
|
||||
SysParameterDO sysParameterDO = convert(sysConfig);
|
||||
saveOrUpdate(sysParameterDO);
|
||||
}
|
||||
|
||||
private SysParameter convert(SysParameterDO sysParameterDO) {
|
||||
SysParameter sysParameter = new SysParameter();
|
||||
private SystemConfig convert(SysParameterDO sysParameterDO) {
|
||||
SystemConfig sysParameter = new SystemConfig();
|
||||
sysParameter.setId(sysParameterDO.getId());
|
||||
List<Parameter> parameters = JsonUtil.toObject(sysParameterDO.getParameters(),
|
||||
new TypeReference<List<Parameter>>() {
|
||||
@@ -47,7 +47,7 @@ public class SysParameterServiceImpl
|
||||
return sysParameter;
|
||||
}
|
||||
|
||||
private SysParameterDO convert(SysParameter sysParameter) {
|
||||
private SysParameterDO convert(SystemConfig sysParameter) {
|
||||
SysParameterDO sysParameterDO = new SysParameterDO();
|
||||
sysParameterDO.setId(sysParameter.getId());
|
||||
sysParameterDO.setParameters(JSONObject.toJSONString(sysParameter.getParameters()));
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.config.MapperConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -27,7 +28,10 @@ import java.util.stream.Collectors;
|
||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
protected MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
protected MapperConfig mapperConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
|
||||
@@ -5,12 +5,10 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -22,6 +20,9 @@ import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
|
||||
|
||||
/**
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||
* It currently supports fuzzy matching against names and aliases.
|
||||
@@ -30,10 +31,6 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
@Override
|
||||
@@ -94,9 +91,8 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
|
||||
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
|
||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
@@ -22,6 +21,14 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_BATCH;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MAX;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MIN;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
|
||||
|
||||
/**
|
||||
* EmbeddingMatchStrategy uses vector database to perform
|
||||
* similarity search against the embeddings of schema elements.
|
||||
@@ -30,9 +37,6 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
|
||||
@@ -48,24 +52,27 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MIN));
|
||||
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MAX));
|
||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||
&& detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin()
|
||||
&& detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax())
|
||||
&& detectSegment.length() >= embedddingMapperMin
|
||||
&& detectSegment.length() <= embedddingMapperMax)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
embeddingMapperBatch);
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
|
||||
@@ -74,15 +81,16 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, QueryContext queryContext) {
|
||||
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(),
|
||||
optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum());
|
||||
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
||||
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
|
||||
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
|
||||
// step2. retrieveQuery by detectSegment
|
||||
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
@@ -118,7 +126,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// step4. select mapResul in one round
|
||||
int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size();
|
||||
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
||||
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||
.limit(roundNumber)
|
||||
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -21,6 +20,14 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_MAX_SIZE;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_SIZE;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DIMENSION_VALUE_SIZE;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD_MIN;
|
||||
|
||||
/**
|
||||
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
|
||||
* match schema elements. It currently supports prefix and suffix matching
|
||||
@@ -30,12 +37,6 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Autowired
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@@ -65,7 +66,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
@@ -99,12 +100,13 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step5. take only M dimensionValue or N-M metric/dimension value per rond.
|
||||
int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
|
||||
List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
|
||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||
.limit(optimizationConfig.getOneDetectionDimensionValueSize())
|
||||
.limit(oneDetectionValueSize)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
|
||||
Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
|
||||
List<HanlpMapResult> oneRoundResults = new ArrayList<>();
|
||||
|
||||
// add the dimensionValue if it exists
|
||||
@@ -129,13 +131,14 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
|
||||
Double threshold = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD));
|
||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN));
|
||||
if (mapperHelper.existDimensionValues(natures)) {
|
||||
threshold = optimizationConfig.getDimensionValueThresholdConfig();
|
||||
minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig();
|
||||
threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD));
|
||||
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD_MIN));
|
||||
}
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,11 +2,9 @@ package com.tencent.supersonic.headless.core.chat.mapper;
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Comparator;
|
||||
@@ -20,9 +18,6 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class MapperHelper {
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
|
||||
Integer subRegLength = regOffsetToLength.get(index);
|
||||
if (Objects.nonNull(subRegLength)) {
|
||||
|
||||
@@ -2,12 +2,16 @@ package com.tencent.supersonic.headless.core.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_LONG;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_SHORT;
|
||||
|
||||
/**
|
||||
* This checker can be used by semantic parsers to check if query intent
|
||||
* has already been satisfied by current candidate queries. If so, current
|
||||
@@ -32,12 +36,19 @@ public class SatisfactionChecker {
|
||||
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
|
||||
int queryTextLength = queryText.replaceAll(" ", "").length();
|
||||
double degree = semanticParseInfo.getScore() / queryTextLength;
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) {
|
||||
if (degree < optimizationConfig.getLongTextThreshold()) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
int textLengthThreshold =
|
||||
Integer.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD));
|
||||
double longTextLengthThreshold =
|
||||
Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_LONG));
|
||||
double shortTextLengthThreshold =
|
||||
Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_SHORT));
|
||||
|
||||
if (queryTextLength > textLengthThreshold) {
|
||||
if (degree < longTextLengthThreshold) {
|
||||
return false;
|
||||
}
|
||||
} else if (degree < optimizationConfig.getShortTextThreshold()) {
|
||||
} else if (degree < shortTextLengthThreshold) {
|
||||
return false;
|
||||
}
|
||||
log.info("queryMode:{}, degree:{}, parse info:{}",
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -16,8 +15,7 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
public LLMResp text2sql(LLMReq llmReq) {
|
||||
|
||||
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(
|
||||
SqlGenType.getMode(llmReq.getSqlGenerationMode()));
|
||||
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGenStrategy.generate(llmReq);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
|
||||
@@ -11,8 +11,8 @@ import com.tencent.supersonic.headless.core.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import com.tencent.supersonic.headless.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
|
||||
@@ -31,14 +31,18 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
public boolean isSkip(QueryContext queryCtx) {
|
||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||
@@ -86,7 +90,9 @@ public class LLMRequestService {
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
List<ElementValue> linking = new ArrayList<>();
|
||||
if (optimizationConfig.isUseLinkingValueSwitch()) {
|
||||
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
|
||||
if (linkingValueEnabled) {
|
||||
linking.addAll(linkingValues);
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
@@ -96,7 +102,7 @@ public class LLMRequestService {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName());
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
@@ -18,6 +18,11 @@ import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
@@ -27,11 +32,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
fewShotNumber, selfConsistencyNumber);
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and parallel generate response.
|
||||
List<String> linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true);
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.slf4j.Logger;
|
||||
@@ -25,7 +25,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
|
||||
protected ExemplarManager exemplarManager;
|
||||
|
||||
@Autowired
|
||||
protected OptimizationConfig optimizationConfig;
|
||||
protected ParserConfig parserConfig;
|
||||
|
||||
@Autowired
|
||||
protected PromptGenerator promptGenerator;
|
||||
|
||||
@@ -17,6 +17,10 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
|
||||
@Service
|
||||
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@@ -24,11 +28,15 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples,
|
||||
optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum());
|
||||
fewShotNumber, selfConsistencyNumber);
|
||||
|
||||
//2.generator linking prompt,and parallel generate response.
|
||||
List<String> linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false);
|
||||
|
||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -21,6 +21,8 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_S2SQL_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
@@ -73,8 +75,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
if (!optimizationConfig.isUseS2SqlSwitch()) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE));
|
||||
if (!s2sqlEnable) {
|
||||
return;
|
||||
}
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
|
||||
@@ -22,7 +22,7 @@ public class LLMReq {
|
||||
|
||||
private String priorExts;
|
||||
|
||||
private String sqlGenerationMode;
|
||||
private SqlGenType sqlGenType;
|
||||
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
@@ -82,14 +82,5 @@ public class LLMReq {
|
||||
return name;
|
||||
}
|
||||
|
||||
public static SqlGenType getMode(String name) {
|
||||
for (SqlGenType sqlGenType : SqlGenType.values()) {
|
||||
if (sqlGenType.name.equals(name)) {
|
||||
return sqlGenType;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
package com.tencent.supersonic.headless.core.config;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("HeadlessMapperConfig")
|
||||
public class MapperConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter MAPPER_DETECTION_SIZE =
|
||||
new Parameter("s2.mapper.detection.size", "8",
|
||||
"一次探测返回结果个数",
|
||||
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_DETECTION_MAX_SIZE =
|
||||
new Parameter("s2.mapper.detection.max.size", "20",
|
||||
"一次探测前后缀匹配结果返回个数",
|
||||
"单次前后缀匹配返回的结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_NAME_THRESHOLD =
|
||||
new Parameter("s2.mapper.name.threshold", "0.3",
|
||||
"指标名、维度名文本相似度阈值",
|
||||
"文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_NAME_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.name.min.threshold", "0.25",
|
||||
"指标名、维度名最小文本相似度阈值",
|
||||
"指标名、维度名相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_DIMENSION_VALUE_SIZE =
|
||||
new Parameter("s2.mapper.value.size", "1",
|
||||
"指标名、维度名最小文本相似度阈值",
|
||||
"指标名、维度名相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_VALUE_THRESHOLD =
|
||||
new Parameter("s2.mapper.value.threshold", "0.5",
|
||||
"维度值文本相似度阈值",
|
||||
"文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_VALUE_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.value.min.threshold", "0.3",
|
||||
"维度值最小文本相似度阈值",
|
||||
"维度值相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_MIN =
|
||||
new Parameter("s2.mapper.embedding.word.min", "4",
|
||||
"用于向量召回最小的文本长度",
|
||||
"为提高向量召回效率, 小于该长度的文本不进行向量语义召回",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_MAX =
|
||||
new Parameter("s2.mapper.embedding.word.max", "5",
|
||||
"用于向量召回最大的文本长度",
|
||||
"为提高向量召回效率, 大于该长度的文本不进行向量语义召回",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_BATCH =
|
||||
new Parameter("s2.mapper.embedding.batch", "50",
|
||||
"批量向量召回文本请求个数",
|
||||
"每次进行向量语义召回的原始文本片段个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.number", "5",
|
||||
"批量向量召回文本返回结果个数",
|
||||
"每个文本进行向量语义召回的文本结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
||||
new Parameter("s2.mapper.embedding.threshold", "0.99",
|
||||
"向量召回相似度阈值",
|
||||
"相似度小于该阈值的则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.embedding.min.threshold", "0.9",
|
||||
"向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.round.number", "10",
|
||||
"向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
MAPPER_DETECTION_SIZE,
|
||||
MAPPER_DETECTION_MAX_SIZE,
|
||||
MAPPER_NAME_THRESHOLD,
|
||||
MAPPER_NAME_THRESHOLD_MIN,
|
||||
MAPPER_DIMENSION_VALUE_SIZE,
|
||||
MAPPER_VALUE_THRESHOLD,
|
||||
MAPPER_VALUE_THRESHOLD_MIN
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,193 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.config;
|
||||
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
@Slf4j
|
||||
public class OptimizationConfig {
|
||||
|
||||
@Value("${s2.one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
|
||||
@Value("${s2.one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
|
||||
@Value("${s2.one.detection.dimensionValue.size:1}")
|
||||
private Integer oneDetectionDimensionValueSize;
|
||||
|
||||
@Value("${s2.metric.dimension.min.threshold:0.3}")
|
||||
private Double metricDimensionMinThresholdConfig;
|
||||
|
||||
@Value("${s2.metric.dimension.threshold:0.3}")
|
||||
private Double metricDimensionThresholdConfig;
|
||||
|
||||
@Value("${s2.dimension.value.min.threshold:0.2}")
|
||||
private Double dimensionValueMinThresholdConfig;
|
||||
|
||||
@Value("${s2.dimension.value.threshold:0.5}")
|
||||
private Double dimensionValueThresholdConfig;
|
||||
|
||||
@Value("${s2.long.text.threshold:0.8}")
|
||||
private Double longTextThreshold;
|
||||
|
||||
@Value("${s2.short.text.threshold:0.5}")
|
||||
private Double shortTextThreshold;
|
||||
|
||||
@Value("${s2.query.text.length.threshold:10}")
|
||||
private Integer queryTextLengthThreshold;
|
||||
|
||||
@Value("${s2.embedding.mapper.word.min:4}")
|
||||
private int embeddingMapperWordMin;
|
||||
|
||||
@Value("${s2.embedding.mapper.word.max:4}")
|
||||
private int embeddingMapperWordMax;
|
||||
|
||||
@Value("${s2.embedding.mapper.batch:50}")
|
||||
private int embeddingMapperBatch;
|
||||
|
||||
@Value("${s2.embedding.mapper.number:5}")
|
||||
private int embeddingMapperNumber;
|
||||
|
||||
@Value("${s2.embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${s2.embedding.mapper.min.threshold:0.6}")
|
||||
private Double embeddingMapperMinThreshold;
|
||||
|
||||
@Value("${s2.embedding.mapper.threshold:0.99}")
|
||||
private Double embeddingMapperThreshold;
|
||||
|
||||
@Value("${s2.parser.linking.value.switch:true}")
|
||||
private boolean useLinkingValueSwitch;
|
||||
|
||||
@Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
|
||||
private LLMReq.SqlGenType sqlGenType;
|
||||
|
||||
@Value("${s2.parser.use.switch:true}")
|
||||
private boolean useS2SqlSwitch;
|
||||
|
||||
@Value("${s2.parser.exemplar-recall.number:15}")
|
||||
private int text2sqlExampleNum;
|
||||
|
||||
@Value("${s2.parser.few-shot.number:5}")
|
||||
private int text2sqlFewShotsNum;
|
||||
|
||||
@Value("${s2.parser.self-consistency.number:5}")
|
||||
private int text2sqlSelfConsistencyNum;
|
||||
|
||||
@Value("${s2.parser.show-count:3}")
|
||||
private Integer parseShowCount;
|
||||
|
||||
@Autowired
|
||||
private SysParameterService sysParameterService;
|
||||
|
||||
public Integer getOneDetectionSize() {
|
||||
return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
|
||||
}
|
||||
|
||||
public Integer getOneDetectionMaxSize() {
|
||||
return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||
}
|
||||
|
||||
public Double getMetricDimensionMinThresholdConfig() {
|
||||
return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getMetricDimensionThresholdConfig() {
|
||||
return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueMinThresholdConfig() {
|
||||
return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueThresholdConfig() {
|
||||
return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getLongTextThreshold() {
|
||||
return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
|
||||
}
|
||||
|
||||
public Double getShortTextThreshold() {
|
||||
return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
|
||||
}
|
||||
|
||||
public Integer getQueryTextLengthThreshold() {
|
||||
return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperWordMin() {
|
||||
return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperWordMax() {
|
||||
return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperBatch() {
|
||||
return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperNumber() {
|
||||
return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperRoundNumber() {
|
||||
return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperMinThreshold() {
|
||||
return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperThreshold() {
|
||||
return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||
}
|
||||
|
||||
public boolean isUseS2SqlSwitch() {
|
||||
return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
|
||||
}
|
||||
|
||||
public boolean isUseLinkingValueSwitch() {
|
||||
return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||
}
|
||||
|
||||
public LLMReq.SqlGenType getSqlGenType() {
|
||||
return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
|
||||
}
|
||||
|
||||
public Integer getParseShowCount() {
|
||||
return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
|
||||
}
|
||||
|
||||
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||
try {
|
||||
String value = sysParameterService.getSysParameter().getParameterByName(paramName);
|
||||
if (StringUtils.isBlank(value)) {
|
||||
return defaultValue;
|
||||
}
|
||||
if (targetType == Double.class) {
|
||||
return targetType.cast(Double.parseDouble(value));
|
||||
} else if (targetType == Integer.class) {
|
||||
return targetType.cast(Integer.parseInt(value));
|
||||
} else if (targetType == Boolean.class) {
|
||||
return targetType.cast(Boolean.parseBoolean(value));
|
||||
} else if (targetType == LLMReq.SqlGenType.class) {
|
||||
return targetType.cast(LLMReq.SqlGenType.valueOf(value));
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("convertValue", e);
|
||||
}
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package com.tencent.supersonic.headless.core.config;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("HeadlessParserConfig")
|
||||
@Slf4j
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_STRATEGY_TYPE =
|
||||
new Parameter("s2.parser.strategy", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
|
||||
"LLM解析生成S2SQL策略",
|
||||
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
|
||||
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql",
|
||||
"list", "Parser相关配置", Lists.newArrayList(
|
||||
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
|
||||
|
||||
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
|
||||
new Parameter("s2.parser.linking.value.enable", "true",
|
||||
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
|
||||
new Parameter("s2.parser.text.length.threshold", "10",
|
||||
"用户输入文本长短阈值", "文本超过该阈值为长文本",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
|
||||
new Parameter("s2.parser.text.threshold", "0.5",
|
||||
"短文本匹配阈值",
|
||||
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
|
||||
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
|
||||
new Parameter("s2.parser.text.threshold", "0.8",
|
||||
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER =
|
||||
new Parameter("s2.parser.exemplar-recall.number", "10",
|
||||
"exemplar召回个数", "",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_FEW_SHOT_NUMBER =
|
||||
new Parameter("s2.parser.few-shot.number", "5",
|
||||
"few-shot样例个数", "样例越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
|
||||
new Parameter("s2.parser.self-consistency.number", "1",
|
||||
"self-consistency执行个数", "执行越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_SHOW_COUNT =
|
||||
new Parameter("s2.parser.show.count", "3",
|
||||
"解析结果展示个数", "前端展示的解析个数",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_S2SQL_ENABLE =
|
||||
new Parameter("s2.parser.s2sql.switch", "true",
|
||||
"", "",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
PARSER_STRATEGY_TYPE,
|
||||
PARSER_LINKING_VALUE_ENABLE,
|
||||
PARSER_TEXT_LENGTH_THRESHOLD,
|
||||
PARSER_TEXT_LENGTH_THRESHOLD_SHORT,
|
||||
PARSER_TEXT_LENGTH_THRESHOLD_LONG,
|
||||
PARSER_FEW_SHOT_NUMBER,
|
||||
PARSER_SELF_CONSISTENCY_NUMBER,
|
||||
PARSER_SHOW_COUNT
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.config.ParserConfig;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -26,6 +26,8 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SHOW_COUNT;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@@ -51,8 +53,8 @@ public class QueryContext {
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Integer parseShowCount = optimizationConfig.getParseShowCount();
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
int parseShowCount = Integer.valueOf(parserConfig.getParameterValue(PARSER_SHOW_COUNT));
|
||||
candidateQueries = candidateQueries.stream()
|
||||
.sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(),
|
||||
Comparator.reverseOrder()))
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
|
||||
@@ -84,7 +84,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
@Autowired
|
||||
protected AgentService agentService;
|
||||
@Autowired
|
||||
protected SysParameterService sysParameterService;
|
||||
protected SystemConfigService sysParameterService;
|
||||
@Autowired
|
||||
protected CanvasService canvasService;
|
||||
@Autowired
|
||||
|
||||
@@ -16,7 +16,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.JoinCondition;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.pojo.SystemConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
@@ -150,7 +150,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
}
|
||||
|
||||
public void addSysParameter() {
|
||||
SysParameter sysParameter = new SysParameter();
|
||||
SystemConfig sysParameter = new SystemConfig();
|
||||
sysParameter.setId(1);
|
||||
sysParameter.init();
|
||||
sysParameterService.save(sysParameter);
|
||||
|
||||
@@ -37,7 +37,7 @@ logging:
|
||||
|
||||
s2:
|
||||
parser:
|
||||
strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
|
||||
strategy: ONE_PASS_AUTO_COT_SELF_CONSISTENCY
|
||||
exemplar-recall:
|
||||
number: 5
|
||||
few-shot:
|
||||
|
||||
Reference in New Issue
Block a user