From 7e6639df8323fa048d9ac1f048fcc9c1a5f2381f Mon Sep 17 00:00:00 2001 From: guilinlewis <185641548@qq.com> Date: Mon, 23 Jun 2025 09:47:48 +0800 Subject: [PATCH] =?UTF-8?q?(improvement)(common|headless|chat|auth)=20?= =?UTF-8?q?=E9=89=B4=E6=9D=83=E4=BC=98=E5=8C=96=E4=B8=8E=E5=8F=AC=E5=9B=9E?= =?UTF-8?q?=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1 修复生成的用户token 一生成就失效的问题 2 如果用户设置的token ,需校验是否数据库存在,因为用户可设置一年的token 有泄露风险 3 结果解析优化, 去除不可以解析的情况,解析问题需要改写后的问, 4 召回样例,用相似度,保住至少有一个样例是高相似度的 5 数据集召回,填加完全匹配格式筛选逻辑 --- .../adaptor/DefaultUserAdaptor.java | 3 +- .../authentication/utils/TokenService.java | 21 ++++++++++ .../execute/DataInterpretProcessor.java | 13 ++++++- .../com/hankcs/hanlp/LoadRemoveService.java | 2 +- .../common/pojo/Text2SQLExemplar.java | 2 + .../service/impl/ExemplarServiceImpl.java | 5 ++- .../headless/api/pojo/SemanticParseInfo.java | 18 +++++++-- .../chat/parser/llm/PromptHelper.java | 38 ++++++++++++++----- 8 files changed, 84 insertions(+), 18 deletions(-) diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java index 8e24db424..bd950736a 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java @@ -222,8 +222,9 @@ public class DefaultUserAdaptor implements UserAdaptor { new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(), userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin()); + // 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况 String token = - tokenService.generateToken(UserWithPassword.convert(userWithPassword), expireTime); + tokenService.generateToken(UserWithPassword.convert(userWithPassword),"SysDbToken:"+name, (new Date().getTime() + expireTime)); UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime); return convertUserToken(userTokenDO); } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/TokenService.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/TokenService.java index a8b249602..dae100f11 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/TokenService.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/TokenService.java @@ -6,7 +6,10 @@ import javax.crypto.spec.SecretKeySpec; import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig; import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword; +import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserTokenDO; +import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository; import com.tencent.supersonic.common.pojo.exception.AccessException; +import com.tencent.supersonic.common.util.ContextUtils; import io.jsonwebtoken.Claims; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; @@ -71,6 +74,7 @@ public class TokenService { return generateToken(UserWithPassword.convert(appUser), request); } + public Optional getClaims(HttpServletRequest request) { String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey()); String appKey = getAppKey(request); @@ -90,6 +94,13 @@ public class TokenService { public Optional getClaims(String token, String appKey) { try { + if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置 + UserRepository userRepository = ContextUtils.getBean(UserRepository.class); + UserTokenDO dbToken= userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length())); + if(dbToken==null||!dbToken.getToken().equals(token.replace("Bearer ",""))) { + throw new AccessException("Token does not exist :" + appKey); + } + } String tokenSecret = getTokenSecret(appKey); Claims claims = Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8)) @@ -122,6 +133,16 @@ public class TokenService { Map appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap(); String secret = appKeyToSecretMap.get(appKey); if (StringUtils.isBlank(secret)) { + if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌 + String realAppKey=appKey.substring("SysDbToken:".length()); + String tmp = "WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ=="; + if(tmp.length()<=realAppKey.length()) { + return realAppKey; + } + else{ + return realAppKey+tmp.substring(realAppKey.length()); + } + } throw new AccessException("get secret from appKey failed :" + appKey); } return secret; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java index fc0f12e99..28b3e5f72 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java @@ -47,7 +47,8 @@ public class DataInterpretProcessor implements ExecuteResultProcessor { Agent agent = executeContext.getAgent(); ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); return Objects.nonNull(chatApp) && chatApp.isEnable() - && StringUtils.isNotBlank(executeContext.getResponse().getTextResult()); // 如果都没结果,则无法处理,直接跳过 + && StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理 + && StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理 } @Override @@ -57,7 +58,15 @@ public class DataInterpretProcessor implements ExecuteResultProcessor { ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); Map variable = new HashMap<>(); - variable.put("question", executeContext.getRequest().getQueryText()); + String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大 + if(executeContext.getParseInfo().getProperties()!=null&& + executeContext.getParseInfo().getProperties().containsKey("CONTEXT")){ + Map context = (Map) executeContext.getParseInfo().getProperties().get("CONTEXT"); + if(context.get("queryText")!=null&&"".equals(context.get("queryText"))){ + question = context.get("queryText").toString(); + } + } + variable.put("question", question); variable.put("data", queryResult.getTextResult()); Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable); diff --git a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java index 3472a6277..aea4d2421 100644 --- a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java +++ b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java @@ -21,7 +21,7 @@ public class LoadRemoveService { List resultList = new ArrayList<>(value); if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) { resultList.removeIf(nature -> { - if (Objects.isNull(nature)) { + if (Objects.isNull(nature)||!nature.startsWith("_")) { // 系统的字典是以 _ 开头的, 过滤因引用外部字典导致的异常 return false; } Long id = getId(nature); diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java index d878c13c2..c4785b3a1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java @@ -22,4 +22,6 @@ public class Text2SQLExemplar implements Serializable { private String dbSchema; private String sql; + + protected double similarity; // 传递相似度,可以作为样本筛选的依据 } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java index 8852375e0..a25b6d08a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java @@ -72,7 +72,10 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { embeddingService.retrieveQuery(collection, retrieveQuery, num); results.forEach(ret -> { ret.getRetrieval().forEach(r -> { - exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class)); + Text2SQLExemplar tmp = //传递相似度,可以作为样本筛选的依据 + JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class); + tmp.setSimilarity(r.getSimilarity()); + exemplars.add(tmp); }); }); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index d7c9df4d8..6c10e824b 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT; import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT; @@ -65,12 +66,23 @@ public class SemanticParseInfo implements Serializable { DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches()); double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity(); - if (difference == 0) { + if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势 + if (!o1.getDataSetId().equals(o2.getDataSetId())) { + List elementMatches1 = o1.getElementMatches().stream() + .filter(e -> e.getSimilarity() == 1).collect(Collectors.toList()); + List elementMatches2 = o2.getElementMatches().stream() + .filter(e -> e.getSimilarity() == 1).collect(Collectors.toList()); + if (elementMatches1.size() > elementMatches2.size()) { + return -1; + } else if (elementMatches1.size() < elementMatches2.size()) { + return 1; + } + } difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity(); - if (difference == 0) { + if (Math.abs(difference) < 0.0005) { difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity(); } - if (difference == 0) { + if (Math.abs(difference) < 0.0005) { difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt(); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index a319b8491..f438e8a43 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -14,10 +14,8 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; +import java.util.*; +import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*; @@ -51,13 +49,33 @@ public class PromptHelper { // use random collection of exemplars for each self-consistency inference for (int i = 0; i < selfConsistencyNumber; i++) { List shuffledList = new ArrayList<>(exemplars); - // only shuffle the exemplars from config - List subList = - shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size()); - Collections.shuffle(subList); - results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber))); + List same = shuffledList.stream() // 相似度极高的话,先找出来 + .filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList()); + List noSame = shuffledList.stream() + .filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList()); + if ((noSame.size() - same.size()) > fewShotNumber) {// 去除部分最低分 + noSame.sort(Comparator.comparingDouble(Text2SQLExemplar::getSimilarity)); + noSame = noSame.subList((noSame.size() - fewShotNumber) / 2, noSame.size()); + } + Text2SQLExemplar mostSimilar = noSame.get(noSame.size() - 1); + Collections.shuffle(noSame); + List ts; + if (same.size() > 0) {// 一样的话,必须作为提示语 + ts = new ArrayList<>(); + int needSize = Math.min(noSame.size() + same.size(), fewShotNumber); + if (needSize > same.size()) { + ts.addAll(noSame.subList(0, needSize - same.size())); + } + ts.addAll(same); + } else { // 至少要一个最像的 + ts = noSame.subList(0, Math.min(noSame.size(), fewShotNumber)); + if (!ts.contains(mostSimilar)) { + ts.remove(ts.size() - 1); + ts.add(mostSimilar); + } + } + results.add(ts); } - return results; }