From d6a386ad0317d27dd189cee5bbc1e8faab3879b1 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 17 Nov 2023 18:11:07 +0800 Subject: [PATCH] [improvement](project) Parameters are uniformly obtained from system settings, removing optimization.properties, and modifying SysParameter parameters (#399) --- .../chat/config/OptimizationConfig.java | 119 +++++++++++++++--- .../supersonic/common/pojo/SysParameter.java | 83 +++++++----- .../service/impl/SysParameterServiceImpl.java | 10 +- .../main/resources/optimization.properties | 10 -- .../main/resources/optimization.properties | 11 -- 5 files changed, 165 insertions(+), 68 deletions(-) delete mode 100644 launchers/chat/src/main/resources/optimization.properties delete mode 100644 launchers/standalone/src/main/resources/optimization.properties diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java index fcdd2af65..630e9ec49 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java @@ -1,41 +1,41 @@ package com.tencent.supersonic.chat.config; +import com.tencent.supersonic.common.service.SysParameterService; import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Configuration; -import org.springframework.context.annotation.PropertySource; @Configuration @Data -@PropertySource("classpath:optimization.properties") +@Slf4j public class OptimizationConfig { - @Value("${one.detection.size}") + @Value("${one.detection.size:8}") private Integer oneDetectionSize; - @Value("${one.detection.max.size}") + + @Value("${one.detection.max.size:20}") private Integer oneDetectionMaxSize; - @Value("${metric.dimension.min.threshold}") + @Value("${metric.dimension.min.threshold:0.3}") private Double metricDimensionMinThresholdConfig; - @Value("${metric.dimension.threshold}") + @Value("${metric.dimension.threshold:0.3}") private Double metricDimensionThresholdConfig; - @Value("${dimension.value.threshold}") + @Value("${dimension.value.threshold:0.5}") private Double dimensionValueThresholdConfig; - @Value("${long.text.threshold}") + @Value("${long.text.threshold:0.8}") private Double longTextThreshold; - @Value("${short.text.threshold}") + @Value("${short.text.threshold:0.5}") private Double shortTextThreshold; - @Value("${query.text.length.threshold}") + @Value("${query.text.length.threshold:10}") private Integer queryTextLengthThreshold; - - @Value("${use.s2SQL.switch:false}") - private boolean useS2SqlSwitch; - @Value("${embedding.mapper.word.min:4}") private int embeddingMapperWordMin; @@ -54,6 +54,95 @@ public class OptimizationConfig { @Value("${embedding.mapper.distance.threshold:0.58}") private Double embeddingMapperDistanceThreshold; - @Value("${use.linking.value.switch:true}") + @Value("${s2SQL.linking.value.switch:true}") private boolean useLinkingValueSwitch; + + @Value("${s2SQL.use.switch:true}") + private boolean useS2SqlSwitch; + @Autowired + private SysParameterService sysParameterService; + + public Integer getOneDetectionSize() { + return convertValue("one.detection.size", Integer.class, oneDetectionSize); + } + + public Integer getOneDetectionMaxSize() { + return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize); + } + + public Double getMetricDimensionMinThresholdConfig() { + return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig); + } + + public Double getMetricDimensionThresholdConfig() { + return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig); + } + + public Double getDimensionValueThresholdConfig() { + return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig); + } + + public Double getLongTextThreshold() { + return convertValue("long.text.threshold", Double.class, longTextThreshold); + } + + public Double getShortTextThreshold() { + return convertValue("short.text.threshold", Double.class, shortTextThreshold); + } + + public Integer getQueryTextLengthThreshold() { + return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold); + } + + public boolean isUseS2SqlSwitch() { + return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch); + } + + public Integer getEmbeddingMapperWordMin() { + return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin); + } + + public Integer getEmbeddingMapperWordMax() { + return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax); + } + + public Integer getEmbeddingMapperBatch() { + return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch); + } + + public Integer getEmbeddingMapperNumber() { + return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber); + } + + public Integer getEmbeddingMapperRoundNumber() { + return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); + } + + public Double getEmbeddingMapperDistanceThreshold() { + return convertValue("embedding.mapper.distance.threshold", Double.class, embeddingMapperDistanceThreshold); + } + + public boolean isUseLinkingValueSwitch() { + return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch); + } + + public T convertValue(String paramName, Class targetType, T defaultValue) { + try { + String value = sysParameterService.getSysParameter().getParameterByName(paramName); + if (StringUtils.isBlank(value)) { + return defaultValue; + } + if (targetType == Double.class) { + return targetType.cast(Double.parseDouble(value)); + } else if (targetType == Integer.class) { + return targetType.cast(Integer.parseInt(value)); + } else if (targetType == Boolean.class) { + return targetType.cast(Boolean.parseBoolean(value)); + } + } catch (Exception e) { + log.error("convertValue", e); + } + return defaultValue; + } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java index 02f40143e..6fcdc9544 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java @@ -1,11 +1,14 @@ package com.tencent.supersonic.common.pojo; import com.google.common.collect.Lists; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import lombok.Data; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; -import java.util.Arrays; -import java.util.List; +import retrofit2.http.HEAD; @Data public class SysParameter { @@ -23,6 +26,15 @@ public class SysParameter { return StringUtils.join(admins, ","); } + public String getParameterByName(String name) { + if (StringUtils.isBlank(name)) { + return ""; + } + Map nameToValue = parameters.stream() + .collect(Collectors.toMap(a -> a.getName(), a -> a.getValue(), (k1, k2) -> k1)); + return nameToValue.get(name); + } + public void setAdminList(String admin) { if (StringUtils.isNotBlank(admin)) { admins = Arrays.asList(admin.split(",")); @@ -34,40 +46,53 @@ public class SysParameter { public void init() { parameters = Lists.newArrayList(); admins = Lists.newArrayList("admin"); - Parameter parameter = new Parameter("llm.model.name", "gpt4", - "模型名称", "list", "大语言模型相关配置"); - parameter.setCandidateValues(Lists.newArrayList("gpt3.5", "gpt3.5-16k")); - parameters.add(parameter); - parameters.add(new Parameter("llm.api.key", "sk-secret", - "模型密钥", "string", "大语言模型相关配置")); + //llm config + parameters.add(new Parameter("llm.model.name", "gpt4", + "模型名称(大语言模型相关配置)", "string", "大语言模型相关配置")); + parameters.add(new Parameter("llm.api.key", "sk-afdasdasd", + "模型密钥(大语言模型相关配置)", "string", "大语言模型相关配置")); + parameters.add(new Parameter("llm.temperature", "0.0", + "温度值", "number", "大语言模型相关配置")); + + //detect config parameters.add(new Parameter("one.detection.size", "8", - "一次探测个数", "number", "[mapper]hanlp相关配置")); + "一次探测个数(hanlp相关配置)", "number", "hanlp相关配置")); parameters.add(new Parameter("one.detection.max.size", "20", - "一次探测最大个数", "number", "[mapper]hanlp相关配置")); + "一次探测最大个数(hanlp相关配置)", "number", "hanlp相关配置")); + + //mapper config parameters.add(new Parameter("metric.dimension.min.threshold", "0.3", - "指标名、维度名最小文本相似度", "number", "[mapper]模糊匹配相关配置")); + "指标名、维度名最小文本相似度(mapper模糊匹配相关配置)", "number", "mapper模糊匹配相关配置")); parameters.add(new Parameter("metric.dimension.threshold", "0.3", - "指标名、维度名文本相似度", "number", "[mapper]模糊匹配相关配置")); + "指标名、维度名文本相似度(mapper模糊匹配相关配置)", "number", "mapper模糊匹配相关配置")); parameters.add(new Parameter("dimension.value.threshold", "0.5", - "维度值最小文本相似度", "number", "[mapper]模糊匹配相关配置")); + "维度值最小文本相似度(mapper模糊匹配相关配置)", "number", "mapper模糊匹配相关配置")); + + //skip config + parameters.add(new Parameter("query.text.length.threshold", "10", + "文本长短阈值(是否跳过当前parser相关配置)", "number", "是否跳过当前parser相关配置")); + parameters.add(new Parameter("short.text.threshold", "5", + "短文本匹配阈值(是否跳过当前parser相关配置)", "number", "是否跳过当前parser相关配置")); + parameters.add(new Parameter("long.text.threshold", "0.8", + "长文本匹配阈值(是否跳过当前parser相关配置)", "number", "是否跳过当前parser相关配置")); + + //embedding mapper config parameters.add(new Parameter("embedding.mapper.word.min", - "0.3", "用于向量召回最小的文本长度", "number", "[mapper]向量召回相关配置")); - parameters.add(new Parameter("embedding.mapper.word.max", "0.3", - "用于向量召回最大的文本长度", "number", "[mapper]向量召回相关配置")); - parameters.add(new Parameter("embedding.mapper.batch", "0.3", - "批量向量召回文本请求个数", "number", "[mapper]向量召回相关配置")); - parameters.add(new Parameter("embedding.mapper.number", "0.3", - "批量向量召回文本返回结果个数", "number", "[mapper]向量召回相关配置")); + "4", "用于向量召回最小的文本长度(向量召回mapper相关配置)", "number", "向量召回mapper相关配置")); + parameters.add(new Parameter("embedding.mapper.word.max", "5", + "用于向量召回最大的文本长度(向量召回mapper相关配置)", "number", "向量召回mapper相关配置")); + parameters.add(new Parameter("embedding.mapper.batch", "50", + "批量向量召回文本请求个数(向量召回mapper相关配置)", "number", "向量召回mapper相关配置")); + parameters.add(new Parameter("embedding.mapper.number", "5", + "批量向量召回文本返回结果个数(向量召回mapper相关配置)", "number", "向量召回mapper相关配置")); parameters.add(new Parameter("embedding.mapper.distance.threshold", - "0.3", "Mapper阶段向量召回相似度阈值", "number", "[mapper]向量召回相关配置")); - parameters.add(new Parameter("query.text.length.threshold", "0.5", - "文本长短阈值", "number", "[parser]是否跳过当前parser相关配置")); - parameters.add(new Parameter("short.text.threshold", "0.5", - "短文本匹配阈值", "number", "[parser]是否跳过当前parser相关配置")); - parameters.add(new Parameter("long.text.threshold", "0.5", - "长文本匹配阈值", "number", "[parser]是否跳过当前parser相关配置")); - parameters.add(new Parameter("use.s2SQL.switch", "true", - "是否打开S2SQL转换开关", "bool", "S2SQL相关配置")); + "0.58", "Mapper阶段向量召回相似度阈值(向量召回mapper相关配置)", "number", "向量召回mapper相关配置")); + + //s2SQL config + parameters.add(new Parameter("s2SQL.generation", "2-steps", + "S2SQL生成方式", "string", "S2SQL相关配置")); + parameters.add(new Parameter("s2SQL.linking.value.switch", "true", + "是否将linkingValues提供给大模型", "bool", "S2SQL相关配置")); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java index d675d6248..dd87a0276 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/SysParameterServiceImpl.java @@ -2,13 +2,16 @@ package com.tencent.supersonic.common.service.impl; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; -import com.tencent.supersonic.common.pojo.SysParameter; +import com.fasterxml.jackson.core.type.TypeReference; import com.tencent.supersonic.common.persistence.dataobject.SysParameterDO; import com.tencent.supersonic.common.persistence.mapper.SysParameterMapper; +import com.tencent.supersonic.common.pojo.Parameter; +import com.tencent.supersonic.common.pojo.SysParameter; import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.common.util.JsonUtil; +import java.util.List; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import java.util.List; @Service public class SysParameterServiceImpl @@ -36,7 +39,8 @@ public class SysParameterServiceImpl private SysParameter convert(SysParameterDO sysParameterDO) { SysParameter sysParameter = new SysParameter(); sysParameter.setId(sysParameterDO.getId()); - sysParameter.setParameters(JSONObject.parseObject(sysParameterDO.getParameters(), List.class)); + List parameters = JsonUtil.toObject(sysParameterDO.getParameters(), new TypeReference>() {}); + sysParameter.setParameters(parameters); sysParameter.setAdminList(sysParameterDO.getAdmin()); return sysParameter; } diff --git a/launchers/chat/src/main/resources/optimization.properties b/launchers/chat/src/main/resources/optimization.properties deleted file mode 100644 index bd4e14d4e..000000000 --- a/launchers/chat/src/main/resources/optimization.properties +++ /dev/null @@ -1,10 +0,0 @@ -one.detection.size=8 -one.detection.max.size=20 -metric.dimension.min.threshold=0.3 -metric.dimension.threshold=0.3 -dimension.value.threshold=0.5 -function.bonus.threshold=201 -long.text.threshold=0.8 -short.text.threshold=0.5 -query.text.length.threshold=10 -candidate.threshold=0.2 diff --git a/launchers/standalone/src/main/resources/optimization.properties b/launchers/standalone/src/main/resources/optimization.properties deleted file mode 100644 index 628025850..000000000 --- a/launchers/standalone/src/main/resources/optimization.properties +++ /dev/null @@ -1,11 +0,0 @@ -one.detection.size=8 -one.detection.max.size=20 -metric.dimension.min.threshold=0.3 -metric.dimension.threshold=0.3 -dimension.value.threshold=0.5 -function.bonus.threshold=201 -long.text.threshold=0.8 -short.text.threshold=0.5 -query.text.length.threshold=10 -candidate.threshold=0.2 -use.s2SQL.switch=true \ No newline at end of file