diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java index edb51159e..0397e667e 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/utils/UserHolder.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.auth.api.authentication.utils; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.service.UserStrategy; -import com.tencent.supersonic.common.pojo.SysParameter; -import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.common.pojo.SystemConfig; +import com.tencent.supersonic.common.service.SystemConfigService; import com.tencent.supersonic.common.util.ContextUtils; import org.springframework.util.CollectionUtils; @@ -20,8 +20,8 @@ public final class UserHolder { public static User findUser(HttpServletRequest request, HttpServletResponse response) { User user = REPO.findUser(request, response); - SysParameterService sysParameterService = ContextUtils.getBean(SysParameterService.class); - SysParameter sysParameter = sysParameterService.getSysParameter(); + SystemConfigService sysParameterService = ContextUtils.getBean(SystemConfigService.class); + SystemConfig sysParameter = sysParameterService.getSysParameter(); if (!CollectionUtils.isEmpty(sysParameter.getAdmins()) && sysParameter.getAdmins().contains(user.getName())) { user.setIsAdmin(1); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java index 6cd5c05c5..4b3bf7eb4 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java @@ -6,8 +6,8 @@ import com.tencent.supersonic.auth.api.authentication.request.UserReq; import com.tencent.supersonic.auth.api.authentication.service.UserService; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.authentication.utils.ComponentFactory; -import com.tencent.supersonic.common.pojo.SysParameter; -import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.common.pojo.SystemConfig; +import com.tencent.supersonic.common.service.SystemConfigService; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import javax.servlet.http.HttpServletRequest; @@ -18,9 +18,9 @@ import java.util.Set; @Service public class UserServiceImpl implements UserService { - private SysParameterService sysParameterService; + private SystemConfigService sysParameterService; - public UserServiceImpl(SysParameterService sysParameterService) { + public UserServiceImpl(SystemConfigService sysParameterService) { this.sysParameterService = sysParameterService; } @@ -28,7 +28,7 @@ public class UserServiceImpl implements UserService { public User getCurrentUser(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); if (user != null) { - SysParameter sysParameter = sysParameterService.getSysParameter(); + SystemConfig sysParameter = sysParameterService.getSysParameter(); if (!CollectionUtils.isEmpty(sysParameter.getAdmins()) && sysParameter.getAdmins().contains(user.getName())) { user.setIsAdmin(1); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java index a6aac1fc2..94930d361 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -23,7 +23,6 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.core.env.Environment; import java.util.Map; import java.util.HashMap; @@ -32,6 +31,8 @@ import java.util.List; import java.util.stream.Collectors; import java.util.Collections; +import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE; + @Slf4j public class MultiTurnParser implements ChatParser { @@ -51,9 +52,9 @@ public class MultiTurnParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { - Environment environment = ContextUtils.getBean(Environment.class); + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); - Boolean globalMultiTurnConfig = environment.getProperty("s2.parser.multi-turn.enable", Boolean.class); + Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); Boolean multiTurnConfig = agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java index 8182797c1..19608d2d8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java @@ -17,8 +17,8 @@ public class NL2PluginParser implements ChatParser { public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { pluginRecognizers.forEach(pluginRecognizer -> { pluginRecognizer.recognize(chatParseContext, parseResp); - log.info("{} context:{} result:{}", pluginRecognizer.getClass().getSimpleName(), - JsonUtil.toString(chatParseContext), JsonUtil.toString(parseResp)); + log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), + JsonUtil.toString(parseResp)); }); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java new file mode 100644 index 000000000..a2f0d8c4a --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java @@ -0,0 +1,27 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ParameterConfig; +import com.tencent.supersonic.common.pojo.Parameter; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service("ChatParserConfig") +@Slf4j +public class ParserConfig extends ParameterConfig { + + public static final Parameter PARSER_MULTI_TURN_ENABLE = + new Parameter("s2.parser.multi-turn.enable", "false", + "是否开启多轮对话", "开启多轮对话将消耗更多token", + "bool", "Parser相关配置"); + + @Override + public List getSysParameters() { + return Lists.newArrayList( + PARSER_MULTI_TURN_ENABLE + ); + } + +} diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java index 248ab7a92..efb29bab9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java @@ -10,15 +10,17 @@ import java.util.List; @NoArgsConstructor public class Parameter { private String name; - private String value; + private String defaultValue; private String comment; private String description; private String dataType; private String module; private List candidateValues; - public Parameter(String name, String value, String comment, String description, String dataType, String module) { + + public Parameter(String name, String defaultValue, String comment, + String description, String dataType, String module) { this.name = name; - this.value = value; + this.defaultValue = defaultValue; this.comment = comment; this.description = description; this.dataType = dataType; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ParameterConfig.java new file mode 100644 index 000000000..b6c7310ed --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ParameterConfig.java @@ -0,0 +1,49 @@ +package com.tencent.supersonic.common.pojo; + +import com.tencent.supersonic.common.service.SystemConfigService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.env.Environment; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service +@Slf4j +public abstract class ParameterConfig { + + @Autowired + private SystemConfigService sysConfigService; + + @Autowired + private Environment environment; + + protected abstract List getSysParameters(); + + /** + * Parameter value will be derived by following orders: + * 1. `system config` set with user interface + * 2. `system property` set with application.yaml + * 3. `default value` set with parameter declaration + * @param parameter + * @return parameter value + */ + public String getParameterValue(Parameter parameter) { + String paramName = parameter.getName(); + String value = sysConfigService.getSysParameter().getParameterByName(paramName); + try { + if (StringUtils.isBlank(value)) { + if (environment.containsProperty(paramName)) { + value = environment.getProperty(paramName); + } else { + value = parameter.getDefaultValue(); + } + } + } catch (Exception e) { + log.error("Failed to get parameter value for {} with exception: {}", paramName, e); + } + + return value; + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java deleted file mode 100644 index 5e92efaa1..000000000 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ /dev/null @@ -1,111 +0,0 @@ -package com.tencent.supersonic.common.pojo; - -import com.google.common.collect.Lists; -import lombok.Data; -import org.apache.commons.lang3.StringUtils; -import org.springframework.util.CollectionUtils; - -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -@Data -public class SysParameter { - - private Integer id; - - private List admins; - - private List parameters; - - public String getAdmin() { - if (CollectionUtils.isEmpty(admins)) { - return ""; - } - return StringUtils.join(admins, ","); - } - - public String getParameterByName(String name) { - if (StringUtils.isBlank(name)) { - return ""; - } - Map nameToValue = parameters.stream() - .collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (k1, k2) -> k1)); - return nameToValue.get(name); - } - - public void setAdminList(String admin) { - if (StringUtils.isNotBlank(admin)) { - admins = Arrays.asList(admin.split(",")); - } else { - admins = Lists.newArrayList(); - } - } - - public void init() { - parameters = Lists.newArrayList(); - admins = Lists.newArrayList("admin"); - - //detect config - parameters.add(new Parameter("s2.one.detection.size", "8", - "一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", - "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.one.detection.max.size", "20", - "一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置")); - - //mapper config - parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3", - "指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", - "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25", - "指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值", - "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.dimension.value.threshold", "0.5", - "维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", - "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3", - "维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值", - "number", "Mapper相关配置")); - - //embedding mapper config - parameters.add(new Parameter("s2.embedding.mapper.word.min", - "4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.embedding.mapper.word.max", "5", - "用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.embedding.mapper.batch", "50", - "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.embedding.mapper.number", "5", - "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.embedding.mapper.threshold", - "0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置")); - parameters.add(new Parameter("s2.embedding.mapper.min.threshold", - "0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); - - //parser config - Parameter s2SQLParameter = new Parameter("s2.parser.strategy", - "TWO_PASS_AUTO_COT_SELF_CONSISTENCY", - "LLM解析生成S2SQL策略", - "ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql" - + "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置"); - - s2SQLParameter.setCandidateValues(Lists.newArrayList( - "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY")); - parameters.add(s2SQLParameter); - parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true", - "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择", - "bool", "Parser相关配置")); - parameters.add(new Parameter("s2.query.text.length.threshold", "10", - "用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置")); - parameters.add(new Parameter("s2.short.text.threshold", "0.5", - "短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser", - "number", "Parser相关配置")); - parameters.add(new Parameter("s2.long.text.threshold", "0.8", - "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", - "number", "Parser相关配置")); - parameters.add(new Parameter("s2.parse.show-count", "3", - "解析结果个数", "前端展示的解析个数", - "number", "Parser相关配置")); - } - -} diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SystemConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SystemConfig.java new file mode 100644 index 000000000..4d5e62fc6 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SystemConfig.java @@ -0,0 +1,59 @@ +package com.tencent.supersonic.common.pojo; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.util.ContextUtils; +import lombok.Data; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Data +public class SystemConfig { + + private Integer id; + + private List admins; + + private List parameters; + + public String getAdmin() { + if (CollectionUtils.isEmpty(admins)) { + return ""; + } + return StringUtils.join(admins, ","); + } + + public String getParameterByName(String name) { + if (StringUtils.isBlank(name)) { + return ""; + } + Map nameToValue = parameters.stream() + .collect(Collectors.toMap(Parameter::getName, Parameter::getDefaultValue, (k1, k2) -> k1)); + return nameToValue.get(name); + } + + public void setAdminList(String admin) { + if (StringUtils.isNotBlank(admin)) { + admins = Arrays.asList(admin.split(",")); + } else { + admins = Lists.newArrayList(); + } + } + + public void init() { + parameters = Lists.newArrayList(); + admins = Lists.newArrayList("admin"); + + Collection configurableParameters = + ContextUtils.getBeansOfType(ParameterConfig.class).values(); + for (ParameterConfig configParameters : configurableParameters) { + parameters.addAll(configParameters.getSysParameters()); + } + } + +} diff --git a/common/src/main/java/com/tencent/supersonic/common/rest/SysParameterController.java b/common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java similarity index 58% rename from common/src/main/java/com/tencent/supersonic/common/rest/SysParameterController.java rename to common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java index 7cc3fc06e..00308d186 100644 --- a/common/src/main/java/com/tencent/supersonic/common/rest/SysParameterController.java +++ b/common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.common.rest; -import com.tencent.supersonic.common.pojo.SysParameter; -import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.common.pojo.SystemConfig; +import com.tencent.supersonic.common.service.SystemConfigService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; @@ -11,20 +11,20 @@ import org.springframework.web.bind.annotation.RestController; @RestController @RequestMapping({"/api/semantic/parameter"}) -public class SysParameterController { +public class SystemConfigController { @Autowired - private SysParameterService sysParameterService; + private SystemConfigService sysConfigService; @PostMapping - public Boolean save(@RequestBody SysParameter sysParameter) { - sysParameterService.save(sysParameter); + public Boolean save(@RequestBody SystemConfig sysParameter) { + sysConfigService.save(sysParameter); return true; } @GetMapping - public SysParameter get() { - return sysParameterService.getSysParameter(); + public SystemConfig get() { + return sysConfigService.getSysParameter(); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/SysParameterService.java b/common/src/main/java/com/tencent/supersonic/common/service/SystemConfigService.java similarity index 50% rename from common/src/main/java/com/tencent/supersonic/common/service/SysParameterService.java rename to common/src/main/java/com/tencent/supersonic/common/service/SystemConfigService.java index 1c074f3c2..ff830737f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/SysParameterService.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/SystemConfigService.java @@ -2,12 +2,12 @@ package com.tencent.supersonic.common.service; import com.baomidou.mybatisplus.extension.service.IService; import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO; -import com.tencent.supersonic.common.pojo.SysParameter; +import com.tencent.supersonic.common.pojo.SystemConfig; -public interface SysParameterService extends IService { +public interface SystemConfigService extends IService { - SysParameter getSysParameter(); + SystemConfig getSysParameter(); - void save(SysParameter sysParameter); + void save(SystemConfig sysConfig); } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java similarity index 74% rename from common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java rename to common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java index 370ee01ae..e85593214 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java @@ -6,22 +6,22 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO; import com.tencent.supersonic.common.persistence.mapper.SysParameterMapper; import com.tencent.supersonic.common.pojo.Parameter; -import com.tencent.supersonic.common.pojo.SysParameter; -import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.common.pojo.SystemConfig; +import com.tencent.supersonic.common.service.SystemConfigService; import com.tencent.supersonic.common.util.JsonUtil; import java.util.List; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @Service -public class SysParameterServiceImpl - extends ServiceImpl implements SysParameterService { +public class SystemConfigServiceImpl + extends ServiceImpl implements SystemConfigService { @Override - public SysParameter getSysParameter() { + public SystemConfig getSysParameter() { List list = list(); if (CollectionUtils.isEmpty(list)) { - SysParameter sysParameter = new SysParameter(); + SystemConfig sysParameter = new SystemConfig(); sysParameter.setId(1); sysParameter.init(); save(sysParameter); @@ -31,13 +31,13 @@ public class SysParameterServiceImpl } @Override - public void save(SysParameter sysParameter) { - SysParameterDO sysParameterDO = convert(sysParameter); + public void save(SystemConfig sysConfig) { + SysParameterDO sysParameterDO = convert(sysConfig); saveOrUpdate(sysParameterDO); } - private SysParameter convert(SysParameterDO sysParameterDO) { - SysParameter sysParameter = new SysParameter(); + private SystemConfig convert(SysParameterDO sysParameterDO) { + SystemConfig sysParameter = new SystemConfig(); sysParameter.setId(sysParameterDO.getId()); List parameters = JsonUtil.toObject(sysParameterDO.getParameters(), new TypeReference>() { @@ -47,7 +47,7 @@ public class SysParameterServiceImpl return sysParameter; } - private SysParameterDO convert(SysParameter sysParameter) { + private SysParameterDO convert(SystemConfig sysParameter) { SysParameterDO sysParameterDO = new SysParameterDO(); sysParameterDO.setId(sysParameter.getId()); sysParameterDO.setParameters(JSONObject.toJSONString(sysParameter.getParameters())); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java index 3bf8151c1..75ae4d726 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMatchStrategy.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.core.chat.mapper; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.response.S2Term; +import com.tencent.supersonic.headless.core.config.MapperConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper; import lombok.extern.slf4j.Slf4j; @@ -27,7 +28,10 @@ import java.util.stream.Collectors; public abstract class BaseMatchStrategy implements MatchStrategy { @Autowired - private MapperHelper mapperHelper; + protected MapperHelper mapperHelper; + + @Autowired + protected MapperConfig mapperConfig; @Override public Map> match(QueryContext queryContext, List terms, diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java index 965a29d28..358b2c532 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/DatabaseMatchStrategy.java @@ -5,12 +5,10 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -22,6 +20,9 @@ import java.util.Map.Entry; import java.util.Set; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN; + /** * DatabaseMatchStrategy uses SQL LIKE operator to match schema elements. * It currently supports fuzzy matching against names and aliases. @@ -30,10 +31,6 @@ import java.util.stream.Collectors; @Slf4j public class DatabaseMatchStrategy extends BaseMatchStrategy { - @Autowired - private OptimizationConfig optimizationConfig; - @Autowired - private MapperHelper mapperHelper; private List allElements; @Override @@ -94,9 +91,8 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy } private Double getThreshold(QueryContext queryContext) { - - Double threshold = optimizationConfig.getMetricDimensionThresholdConfig(); - Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig(); + Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD)); + Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN)); Map> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java index 40a8ddcf4..8bc033eb7 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/EmbeddingMatchStrategy.java @@ -5,7 +5,6 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService; import com.tencent.supersonic.headless.core.pojo.QueryContext; @@ -22,6 +21,14 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_BATCH; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MAX; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_MIN; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_NUMBER; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD; +import static com.tencent.supersonic.headless.core.config.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN; + /** * EmbeddingMatchStrategy uses vector database to perform * similarity search against the embeddings of schema elements. @@ -30,9 +37,6 @@ import java.util.stream.Collectors; @Slf4j public class EmbeddingMatchStrategy extends BaseMatchStrategy { - @Autowired - private OptimizationConfig optimizationConfig; - @Autowired private MetaEmbeddingService metaEmbeddingService; @@ -48,24 +52,27 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } @Override - public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + public void detectByStep(QueryContext queryContext, Set existResults, + Set detectDataSetIds, String detectSegment, int offset) { } @Override - protected void detectByBatch(QueryContext queryContext, Set results, Set detectDataSetIds, - Set detectSegments) { + protected void detectByBatch(QueryContext queryContext, Set results, + Set detectDataSetIds, Set detectSegments) { + int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MIN)); + int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_MAX)); + int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_BATCH)); List queryTextsList = detectSegments.stream() .map(detectSegment -> detectSegment.trim()) .filter(detectSegment -> StringUtils.isNotBlank(detectSegment) - && detectSegment.length() >= optimizationConfig.getEmbeddingMapperWordMin() - && detectSegment.length() <= optimizationConfig.getEmbeddingMapperWordMax()) + && detectSegment.length() >= embedddingMapperMin + && detectSegment.length() <= embedddingMapperMax) .collect(Collectors.toList()); List> queryTextsSubList = Lists.partition(queryTextsList, - optimizationConfig.getEmbeddingMapperBatch()); + embeddingMapperBatch); for (List queryTextsSub : queryTextsSubList) { detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext); @@ -74,15 +81,16 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { private void detectByQueryTextsSub(Set results, Set detectDataSetIds, List queryTextsSub, QueryContext queryContext) { - Map> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds(); - int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber(); - double threshold = getThreshold(optimizationConfig.getEmbeddingMapperThreshold(), - optimizationConfig.getEmbeddingMapperMinThreshold(), queryContext.getMapModeEnum()); + double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD)); + double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN)); + double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum()); + // step1. build query params RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); - // step2. retrieveQuery by detectSegment + // step2. retrieveQuery by detectSegment + int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER)); List retrieveQueryResults = metaEmbeddingService.retrieveQuery( retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds); @@ -118,7 +126,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { .collect(Collectors.toList()); // step4. select mapResul in one round - int roundNumber = optimizationConfig.getEmbeddingMapperRoundNumber() * queryTextsSub.size(); + int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER)); + int roundNumber = embeddingRoundNumber * queryTextsSub.size(); List oneRoundResults = collect.stream() .sorted(Comparator.comparingDouble(EmbeddingResult::getDistance)) .limit(roundNumber) diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java index e086b167c..c7e71c967 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/HanlpDictMatchStrategy.java @@ -4,7 +4,6 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -21,6 +20,14 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_MAX_SIZE; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DETECTION_SIZE; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_DIMENSION_VALUE_SIZE; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_NAME_THRESHOLD_MIN; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD; +import static com.tencent.supersonic.headless.core.config.MapperConfig.MAPPER_VALUE_THRESHOLD_MIN; + /** * HanlpDictMatchStrategy uses HanLP to * match schema elements. It currently supports prefix and suffix matching @@ -30,12 +37,6 @@ import java.util.stream.Collectors; @Slf4j public class HanlpDictMatchStrategy extends BaseMatchStrategy { - @Autowired - private MapperHelper mapperHelper; - - @Autowired - private OptimizationConfig optimizationConfig; - @Autowired private KnowledgeBaseService knowledgeBaseService; @@ -65,7 +66,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, String detectSegment, int offset) { // step1. pre search - Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); + Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE)); LinkedHashSet hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment, oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); @@ -99,12 +100,13 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { }).collect(Collectors.toCollection(LinkedHashSet::new)); // step5. take only M dimensionValue or N-M metric/dimension value per rond. + int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE)); List dimensionValues = hanlpMapResults.stream() .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures())) - .limit(optimizationConfig.getOneDetectionDimensionValueSize()) + .limit(oneDetectionValueSize) .collect(Collectors.toList()); - Integer oneDetectionSize = optimizationConfig.getOneDetectionSize(); + Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE)); List oneRoundResults = new ArrayList<>(); // add the dimensionValue if it exists @@ -129,13 +131,14 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { } public double getThresholdMatch(List natures, QueryContext queryContext) { - Double threshold = optimizationConfig.getMetricDimensionThresholdConfig(); - Double minThreshold = optimizationConfig.getMetricDimensionMinThresholdConfig(); + Double threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD)); + Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_NAME_THRESHOLD_MIN)); if (mapperHelper.existDimensionValues(natures)) { - threshold = optimizationConfig.getDimensionValueThresholdConfig(); - minThreshold = optimizationConfig.getDimensionValueMinThresholdConfig(); + threshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD)); + minThreshold = Double.valueOf(mapperConfig.getParameterValue(MAPPER_VALUE_THRESHOLD_MIN)); } - return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum()); + return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum()); } + } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java index a245a4287..a1d022c07 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/MapperHelper.java @@ -2,11 +2,9 @@ package com.tencent.supersonic.headless.core.chat.mapper; import com.hankcs.hanlp.algorithm.EditDistance; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper; import lombok.Data; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import java.util.Comparator; @@ -20,9 +18,6 @@ import java.util.stream.Collectors; @Slf4j public class MapperHelper { - @Autowired - private OptimizationConfig optimizationConfig; - public Integer getStepIndex(Map regOffsetToLength, Integer index) { Integer subRegLength = regOffsetToLength.get(index); if (Objects.nonNull(subRegLength)) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/SatisfactionChecker.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/SatisfactionChecker.java index 34b491762..53543f510 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/SatisfactionChecker.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/SatisfactionChecker.java @@ -2,12 +2,16 @@ package com.tencent.supersonic.headless.core.chat.parser; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; +import com.tencent.supersonic.headless.core.config.ParserConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery; import lombok.extern.slf4j.Slf4j; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_LONG; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_SHORT; + /** * This checker can be used by semantic parsers to check if query intent * has already been satisfied by current candidate queries. If so, current @@ -32,12 +36,19 @@ public class SatisfactionChecker { private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) { int queryTextLength = queryText.replaceAll(" ", "").length(); double degree = semanticParseInfo.getScore() / queryTextLength; - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - if (queryTextLength > optimizationConfig.getQueryTextLengthThreshold()) { - if (degree < optimizationConfig.getLongTextThreshold()) { + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); + int textLengthThreshold = + Integer.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD)); + double longTextLengthThreshold = + Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_LONG)); + double shortTextLengthThreshold = + Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_SHORT)); + + if (queryTextLength > textLengthThreshold) { + if (degree < longTextLengthThreshold) { return false; } - } else if (degree < optimizationConfig.getShortTextThreshold()) { + } else if (degree < shortTextLengthThreshold) { return false; } log.info("queryMode:{}, degree:{}, parse info:{}", diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java index 69395432d..88550ed15 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/JavaLLMProxy.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -16,8 +15,7 @@ public class JavaLLMProxy implements LLMProxy { public LLMResp text2sql(LLMReq llmReq) { - SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get( - SqlGenType.getMode(llmReq.getSqlGenerationMode())); + SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType()); String modelName = llmReq.getSchema().getDataSetName(); LLMResp result = sqlGenStrategy.generate(llmReq); result.setQuery(llmReq.getQueryText()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 7945bb2a6..f8d37313b 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -11,8 +11,8 @@ import com.tencent.supersonic.headless.core.chat.parser.SatisfactionChecker; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.ElementValue; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.headless.core.config.ParserConfig; import com.tencent.supersonic.headless.core.config.LLMParserConfig; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper; @@ -31,14 +31,18 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_LINKING_VALUE_ENABLE; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_STRATEGY_TYPE; + @Slf4j @Service public class LLMRequestService { @Autowired private LLMParserConfig llmParserConfig; + @Autowired - private OptimizationConfig optimizationConfig; + private ParserConfig parserConfig; public boolean isSkip(QueryContext queryCtx) { if (!queryCtx.getText2SQLType().enableLLM()) { @@ -86,7 +90,9 @@ public class LLMRequestService { llmReq.setSchema(llmSchema); List linking = new ArrayList<>(); - if (optimizationConfig.isUseLinkingValueSwitch()) { + boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE)); + + if (linkingValueEnabled) { linking.addAll(linkingValues); } llmReq.setLinking(linking); @@ -96,7 +102,7 @@ public class LLMRequestService { currentDate = DateUtils.getBeforeDate(0); } llmReq.setCurrentDate(currentDate); - llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenType().getName()); + llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); llmReq.setLlmConfig(queryCtx.getLlmConfig()); return llmReq; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java index 4e1c79aee..315febb79 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -18,6 +18,11 @@ import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER; + + @Service @Slf4j public class OnePassSCSqlGenStrategy extends SqlGenStrategy { @@ -27,11 +32,14 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { //1.retriever sqlExamples and generate exampleListPool keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq); - List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), - optimizationConfig.getText2sqlExampleNum()); + int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); + int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER)); + int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER)); + List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), + exemplarRecallNumber); List>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples, - optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); + fewShotNumber, selfConsistencyNumber); //2.generator linking and sql prompt by sqlExamples,and parallel generate response. List linkingSqlPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, true); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java index c645e4cab..02dd8dbfa 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlGenStrategy.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.core.chat.parser.llm; import com.tencent.supersonic.headless.api.pojo.LLMConfig; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; +import com.tencent.supersonic.headless.core.config.ParserConfig; import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider; import dev.langchain4j.model.chat.ChatLanguageModel; import org.slf4j.Logger; @@ -25,7 +25,7 @@ public abstract class SqlGenStrategy implements InitializingBean { protected ExemplarManager exemplarManager; @Autowired - protected OptimizationConfig optimizationConfig; + protected ParserConfig parserConfig; @Autowired protected PromptGenerator promptGenerator; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java index 9011a8160..91d2616bf 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java @@ -17,6 +17,10 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CopyOnWriteArrayList; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_FEW_SHOT_NUMBER; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER; + @Service public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { @@ -24,11 +28,15 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { public LLMResp generate(LLMReq llmReq) { //1.retriever sqlExamples and generate exampleListPool keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq); - List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), - optimizationConfig.getText2sqlExampleNum()); + int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); + int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER)); + int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER)); + + List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), + exemplarRecallNumber); List>> exampleListPool = promptGenerator.getExampleCombos(sqlExamples, - optimizationConfig.getText2sqlFewShotsNum(), optimizationConfig.getText2sqlSelfConsistencyNum()); + fewShotNumber, selfConsistencyNumber); //2.generator linking prompt,and parallel generate response. List linkingPromptPool = promptGenerator.generatePromptPool(llmReq, exampleListPool, false); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/BaseSemanticQuery.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/BaseSemanticQuery.java index 2d883d490..6b86ad4b0 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/BaseSemanticQuery.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/BaseSemanticQuery.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; +import com.tencent.supersonic.headless.core.config.ParserConfig; import com.tencent.supersonic.headless.core.utils.QueryReqBuilder; import lombok.ToString; import lombok.extern.slf4j.Slf4j; @@ -21,6 +21,8 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_S2SQL_ENABLE; + @Slf4j @ToString public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { @@ -73,8 +75,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { } protected void initS2SqlByStruct(SemanticSchema semanticSchema) { - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - if (!optimizationConfig.isUseS2SqlSwitch()) { + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); + boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE)); + if (!s2sqlEnable) { return; } QueryStructReq queryStructReq = convertQueryStruct(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java index 5dfd89c6c..58456ba87 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java @@ -22,7 +22,7 @@ public class LLMReq { private String priorExts; - private String sqlGenerationMode; + private SqlGenType sqlGenType; private LLMConfig llmConfig; @@ -82,14 +82,5 @@ public class LLMReq { return name; } - public static SqlGenType getMode(String name) { - for (SqlGenType sqlGenType : SqlGenType.values()) { - if (sqlGenType.name.equals(name)) { - return sqlGenType; - } - } - return null; - } - } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/MapperConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/MapperConfig.java new file mode 100644 index 000000000..d57ab0aa9 --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/MapperConfig.java @@ -0,0 +1,110 @@ +package com.tencent.supersonic.headless.core.config; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ParameterConfig; +import com.tencent.supersonic.common.pojo.Parameter; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service("HeadlessMapperConfig") +public class MapperConfig extends ParameterConfig { + + public static final Parameter MAPPER_DETECTION_SIZE = + new Parameter("s2.mapper.detection.size", "8", + "一次探测返回结果个数", + "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", + "number", "Mapper相关配置"); + + public static final Parameter MAPPER_DETECTION_MAX_SIZE = + new Parameter("s2.mapper.detection.max.size", "20", + "一次探测前后缀匹配结果返回个数", + "单次前后缀匹配返回的结果个数", + "number", "Mapper相关配置"); + + public static final Parameter MAPPER_NAME_THRESHOLD = + new Parameter("s2.mapper.name.threshold", "0.3", + "指标名、维度名文本相似度阈值", + "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", + "number", "Mapper相关配置"); + + public static final Parameter MAPPER_NAME_THRESHOLD_MIN = + new Parameter("s2.mapper.name.min.threshold", "0.25", + "指标名、维度名最小文本相似度阈值", + "指标名、维度名相似度阈值在动态调整中的最低值", + "number", "Mapper相关配置"); + + public static final Parameter MAPPER_DIMENSION_VALUE_SIZE = + new Parameter("s2.mapper.value.size", "1", + "指标名、维度名最小文本相似度阈值", + "指标名、维度名相似度阈值在动态调整中的最低值", + "number", "Mapper相关配置"); + + public static final Parameter MAPPER_VALUE_THRESHOLD = + new Parameter("s2.mapper.value.threshold", "0.5", + "维度值文本相似度阈值", + "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", + "number", "Mapper相关配置"); + + public static final Parameter MAPPER_VALUE_THRESHOLD_MIN = + new Parameter("s2.mapper.value.min.threshold", "0.3", + "维度值最小文本相似度阈值", + "维度值相似度阈值在动态调整中的最低值", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_MIN = + new Parameter("s2.mapper.embedding.word.min", "4", + "用于向量召回最小的文本长度", + "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_MAX = + new Parameter("s2.mapper.embedding.word.max", "5", + "用于向量召回最大的文本长度", + "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_BATCH = + new Parameter("s2.mapper.embedding.batch", "50", + "批量向量召回文本请求个数", + "每次进行向量语义召回的原始文本片段个数", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_NUMBER = + new Parameter("s2.mapper.embedding.number", "5", + "批量向量召回文本返回结果个数", + "每个文本进行向量语义召回的文本结果个数", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_THRESHOLD = + new Parameter("s2.mapper.embedding.threshold", "0.99", + "向量召回相似度阈值", + "相似度小于该阈值的则舍弃", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN = + new Parameter("s2.mapper.embedding.min.threshold", "0.9", + "向量召回最小相似度阈值", + "向量召回相似度阈值在动态调整中的最低值", + "number", "Mapper相关配置"); + + public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER = + new Parameter("s2.mapper.embedding.round.number", "10", + "向量召回最小相似度阈值", + "向量召回相似度阈值在动态调整中的最低值", + "number", "Mapper相关配置"); + + @Override + public List getSysParameters() { + return Lists.newArrayList( + MAPPER_DETECTION_SIZE, + MAPPER_DETECTION_MAX_SIZE, + MAPPER_NAME_THRESHOLD, + MAPPER_NAME_THRESHOLD_MIN, + MAPPER_DIMENSION_VALUE_SIZE, + MAPPER_VALUE_THRESHOLD, + MAPPER_VALUE_THRESHOLD_MIN + ); + } + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java deleted file mode 100644 index a940ec3d2..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java +++ /dev/null @@ -1,193 +0,0 @@ -package com.tencent.supersonic.headless.core.config; - -import com.tencent.supersonic.common.service.SysParameterService; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Configuration; - -@Configuration -@Data -@Slf4j -public class OptimizationConfig { - - @Value("${s2.one.detection.size:8}") - private Integer oneDetectionSize; - - @Value("${s2.one.detection.max.size:20}") - private Integer oneDetectionMaxSize; - - @Value("${s2.one.detection.dimensionValue.size:1}") - private Integer oneDetectionDimensionValueSize; - - @Value("${s2.metric.dimension.min.threshold:0.3}") - private Double metricDimensionMinThresholdConfig; - - @Value("${s2.metric.dimension.threshold:0.3}") - private Double metricDimensionThresholdConfig; - - @Value("${s2.dimension.value.min.threshold:0.2}") - private Double dimensionValueMinThresholdConfig; - - @Value("${s2.dimension.value.threshold:0.5}") - private Double dimensionValueThresholdConfig; - - @Value("${s2.long.text.threshold:0.8}") - private Double longTextThreshold; - - @Value("${s2.short.text.threshold:0.5}") - private Double shortTextThreshold; - - @Value("${s2.query.text.length.threshold:10}") - private Integer queryTextLengthThreshold; - - @Value("${s2.embedding.mapper.word.min:4}") - private int embeddingMapperWordMin; - - @Value("${s2.embedding.mapper.word.max:4}") - private int embeddingMapperWordMax; - - @Value("${s2.embedding.mapper.batch:50}") - private int embeddingMapperBatch; - - @Value("${s2.embedding.mapper.number:5}") - private int embeddingMapperNumber; - - @Value("${s2.embedding.mapper.round.number:10}") - private int embeddingMapperRoundNumber; - - @Value("${s2.embedding.mapper.min.threshold:0.6}") - private Double embeddingMapperMinThreshold; - - @Value("${s2.embedding.mapper.threshold:0.99}") - private Double embeddingMapperThreshold; - - @Value("${s2.parser.linking.value.switch:true}") - private boolean useLinkingValueSwitch; - - @Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}") - private LLMReq.SqlGenType sqlGenType; - - @Value("${s2.parser.use.switch:true}") - private boolean useS2SqlSwitch; - - @Value("${s2.parser.exemplar-recall.number:15}") - private int text2sqlExampleNum; - - @Value("${s2.parser.few-shot.number:5}") - private int text2sqlFewShotsNum; - - @Value("${s2.parser.self-consistency.number:5}") - private int text2sqlSelfConsistencyNum; - - @Value("${s2.parser.show-count:3}") - private Integer parseShowCount; - - @Autowired - private SysParameterService sysParameterService; - - public Integer getOneDetectionSize() { - return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize); - } - - public Integer getOneDetectionMaxSize() { - return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize); - } - - public Double getMetricDimensionMinThresholdConfig() { - return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig); - } - - public Double getMetricDimensionThresholdConfig() { - return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig); - } - - public Double getDimensionValueMinThresholdConfig() { - return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig); - } - - public Double getDimensionValueThresholdConfig() { - return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig); - } - - public Double getLongTextThreshold() { - return convertValue("s2.long.text.threshold", Double.class, longTextThreshold); - } - - public Double getShortTextThreshold() { - return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold); - } - - public Integer getQueryTextLengthThreshold() { - return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold); - } - - public Integer getEmbeddingMapperWordMin() { - return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin); - } - - public Integer getEmbeddingMapperWordMax() { - return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax); - } - - public Integer getEmbeddingMapperBatch() { - return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch); - } - - public Integer getEmbeddingMapperNumber() { - return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber); - } - - public Integer getEmbeddingMapperRoundNumber() { - return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); - } - - public Double getEmbeddingMapperMinThreshold() { - return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold); - } - - public Double getEmbeddingMapperThreshold() { - return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold); - } - - public boolean isUseS2SqlSwitch() { - return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch); - } - - public boolean isUseLinkingValueSwitch() { - return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch); - } - - public LLMReq.SqlGenType getSqlGenType() { - return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType); - } - - public Integer getParseShowCount() { - return convertValue("s2.parse.show-count", Integer.class, parseShowCount); - } - - public T convertValue(String paramName, Class targetType, T defaultValue) { - try { - String value = sysParameterService.getSysParameter().getParameterByName(paramName); - if (StringUtils.isBlank(value)) { - return defaultValue; - } - if (targetType == Double.class) { - return targetType.cast(Double.parseDouble(value)); - } else if (targetType == Integer.class) { - return targetType.cast(Integer.parseInt(value)); - } else if (targetType == Boolean.class) { - return targetType.cast(Boolean.parseBoolean(value)); - } else if (targetType == LLMReq.SqlGenType.class) { - return targetType.cast(LLMReq.SqlGenType.valueOf(value)); - } - } catch (Exception e) { - log.error("convertValue", e); - } - return defaultValue; - } - -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/ParserConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/ParserConfig.java new file mode 100644 index 000000000..5d1d9609d --- /dev/null +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/ParserConfig.java @@ -0,0 +1,84 @@ +package com.tencent.supersonic.headless.core.config; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.ParameterConfig; +import com.tencent.supersonic.common.pojo.Parameter; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.util.List; + +@Service("HeadlessParserConfig") +@Slf4j +public class ParserConfig extends ParameterConfig { + + public static final Parameter PARSER_STRATEGY_TYPE = + new Parameter("s2.parser.strategy", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", + "LLM解析生成S2SQL策略", + "ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql" + + "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", + "list", "Parser相关配置", Lists.newArrayList( + "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY")); + + public static final Parameter PARSER_LINKING_VALUE_ENABLE = + new Parameter("s2.parser.linking.value.enable", "true", + "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择", + "bool", "Parser相关配置"); + + public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD = + new Parameter("s2.parser.text.length.threshold", "10", + "用户输入文本长短阈值", "文本超过该阈值为长文本", + "number", "Parser相关配置"); + + public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT = + new Parameter("s2.parser.text.threshold", "0.5", + "短文本匹配阈值", + "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用," + + "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser", + "number", "Parser相关配置"); + + public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG = + new Parameter("s2.parser.text.threshold", "0.8", + "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", + "number", "Parser相关配置"); + + public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = + new Parameter("s2.parser.exemplar-recall.number", "10", + "exemplar召回个数", "", + "number", "Parser相关配置"); + + public static final Parameter PARSER_FEW_SHOT_NUMBER = + new Parameter("s2.parser.few-shot.number", "5", + "few-shot样例个数", "样例越多效果可能越好,但token消耗越大", + "number", "Parser相关配置"); + + public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER = + new Parameter("s2.parser.self-consistency.number", "1", + "self-consistency执行个数", "执行越多效果可能越好,但token消耗越大", + "number", "Parser相关配置"); + + public static final Parameter PARSER_SHOW_COUNT = + new Parameter("s2.parser.show.count", "3", + "解析结果展示个数", "前端展示的解析个数", + "number", "Parser相关配置"); + + public static final Parameter PARSER_S2SQL_ENABLE = + new Parameter("s2.parser.s2sql.switch", "true", + "", "", + "bool", "Parser相关配置"); + + @Override + public List getSysParameters() { + return Lists.newArrayList( + PARSER_STRATEGY_TYPE, + PARSER_LINKING_VALUE_ENABLE, + PARSER_TEXT_LENGTH_THRESHOLD, + PARSER_TEXT_LENGTH_THRESHOLD_SHORT, + PARSER_TEXT_LENGTH_THRESHOLD_LONG, + PARSER_FEW_SHOT_NUMBER, + PARSER_SELF_CONSISTENCY_NUMBER, + PARSER_SHOW_COUNT + ); + } + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index 5b7c34b9d..6a825843d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.core.config.OptimizationConfig; +import com.tencent.supersonic.headless.core.config.ParserConfig; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; @@ -26,6 +26,8 @@ import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import static com.tencent.supersonic.headless.core.config.ParserConfig.PARSER_SHOW_COUNT; + @Data @Builder @NoArgsConstructor @@ -51,8 +53,8 @@ public class QueryContext { private LLMConfig llmConfig; public List getCandidateQueries() { - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - Integer parseShowCount = optimizationConfig.getParseShowCount(); + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); + int parseShowCount = Integer.valueOf(parserConfig.getParameterValue(PARSER_SHOW_COUNT)); candidateQueries = candidateQueries.stream() .sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(), Comparator.reverseOrder())) diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java index bd6ae2143..af0d7a773 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java @@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.PluginService; -import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.common.service.SystemConfigService; import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.RelateDimension; @@ -84,7 +84,7 @@ public abstract class S2BaseDemo implements CommandLineRunner { @Autowired protected AgentService agentService; @Autowired - protected SysParameterService sysParameterService; + protected SystemConfigService sysParameterService; @Autowired protected CanvasService canvasService; @Autowired diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index f3cf62340..5ea160c85 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -16,7 +16,7 @@ import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; -import com.tencent.supersonic.common.pojo.SysParameter; +import com.tencent.supersonic.common.pojo.SystemConfig; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; @@ -150,7 +150,7 @@ public class S2VisitsDemo extends S2BaseDemo { } public void addSysParameter() { - SysParameter sysParameter = new SysParameter(); + SystemConfig sysParameter = new SystemConfig(); sysParameter.setId(1); sysParameter.init(); sysParameterService.save(sysParameter); diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 5d4deb3a6..8aeb28bdc 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -37,7 +37,7 @@ logging: s2: parser: - strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY + strategy: ONE_PASS_AUTO_COT_SELF_CONSISTENCY exemplar-recall: number: 5 few-shot: