(improvement)(common|headless|chat|auth) 鉴权优化与召回优化

1 修复生成的用户token 一生成就失效的问题
2 对于用户设置的 token,需校验其是否在数据库中存在,因为用户可设置有效期为一年的 token,有泄露风险
3 结果解析优化,去除不可以解析的情况,解析问题需要使用改写后的问题
4 召回样例,用相似度,保证至少有一个样例是高相似度的
5 数据集召回,添加完全匹配格式筛选逻辑
This commit is contained in:
guilinlewis
2025-06-23 09:47:48 +08:00
parent 0721df2e66
commit 7e6639df83
8 changed files with 84 additions and 18 deletions

View File

@@ -222,8 +222,9 @@ public class DefaultUserAdaptor implements UserAdaptor {
new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(), new UserWithPassword(userDO.getId(), userDO.getName(), userDO.getDisplayName(),
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin()); userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
// 使用令牌名称作为生成key 这样可以区分正常请求和api 请求api 的令牌失效时间很长,需考虑令牌泄露的情况
String token = String token =
tokenService.generateToken(UserWithPassword.convert(userWithPassword), expireTime); tokenService.generateToken(UserWithPassword.convert(userWithPassword),"SysDbToken:"+name, (new Date().getTime() + expireTime));
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime); UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
return convertUserToken(userTokenDO); return convertUserToken(userTokenDO);
} }

View File

@@ -6,7 +6,10 @@ import javax.crypto.spec.SecretKeySpec;
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig; import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword; import com.tencent.supersonic.auth.api.authentication.pojo.UserWithPassword;
import com.tencent.supersonic.auth.authentication.persistence.dataobject.UserTokenDO;
import com.tencent.supersonic.auth.authentication.persistence.repository.UserRepository;
import com.tencent.supersonic.common.pojo.exception.AccessException; import com.tencent.supersonic.common.pojo.exception.AccessException;
import com.tencent.supersonic.common.util.ContextUtils;
import io.jsonwebtoken.Claims; import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Jwts; import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm; import io.jsonwebtoken.SignatureAlgorithm;
@@ -71,6 +74,7 @@ public class TokenService {
return generateToken(UserWithPassword.convert(appUser), request); return generateToken(UserWithPassword.convert(appUser), request);
} }
public Optional<Claims> getClaims(HttpServletRequest request) { public Optional<Claims> getClaims(HttpServletRequest request) {
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey()); String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
String appKey = getAppKey(request); String appKey = getAppKey(request);
@@ -90,6 +94,13 @@ public class TokenService {
public Optional<Claims> getClaims(String token, String appKey) { public Optional<Claims> getClaims(String token, String appKey) {
try { try {
if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
UserTokenDO dbToken= userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
if(dbToken==null||!dbToken.getToken().equals(token.replace("Bearer ",""))) {
throw new AccessException("Token does not exist :" + appKey);
}
}
String tokenSecret = getTokenSecret(appKey); String tokenSecret = getTokenSecret(appKey);
Claims claims = Claims claims =
Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8)) Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
@@ -122,6 +133,16 @@ public class TokenService {
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap(); Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
String secret = appKeyToSecretMap.get(appKey); String secret = appKeyToSecretMap.get(appKey);
if (StringUtils.isBlank(secret)) { if (StringUtils.isBlank(secret)) {
if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
String realAppKey=appKey.substring("SysDbToken:".length());
String tmp = "WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
if(tmp.length()<=realAppKey.length()) {
return realAppKey;
}
else{
return realAppKey+tmp.substring(realAppKey.length());
}
}
throw new AccessException("get secret from appKey failed :" + appKey); throw new AccessException("get secret from appKey failed :" + appKey);
} }
return secret; return secret;

View File

@@ -47,7 +47,8 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
Agent agent = executeContext.getAgent(); Agent agent = executeContext.getAgent();
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
return Objects.nonNull(chatApp) && chatApp.isEnable() return Objects.nonNull(chatApp) && chatApp.isEnable()
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()); // 如果都没结果,则无法处理,直接跳过 && StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
} }
@Override @Override
@@ -57,7 +58,15 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY); ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
Map<String, Object> variable = new HashMap<>(); Map<String, Object> variable = new HashMap<>();
variable.put("question", executeContext.getRequest().getQueryText()); String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
if(executeContext.getParseInfo().getProperties()!=null&&
executeContext.getParseInfo().getProperties().containsKey("CONTEXT")){
Map<String,Object> context = (Map<String, Object>) executeContext.getParseInfo().getProperties().get("CONTEXT");
if(context.get("queryText")!=null&&"".equals(context.get("queryText"))){
question = context.get("queryText").toString();
}
}
variable.put("question", question);
variable.put("data", queryResult.getTextResult()); variable.put("data", queryResult.getTextResult());
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable); Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);

View File

@@ -21,7 +21,7 @@ public class LoadRemoveService {
List<String> resultList = new ArrayList<>(value); List<String> resultList = new ArrayList<>(value);
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) { if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
resultList.removeIf(nature -> { resultList.removeIf(nature -> {
if (Objects.isNull(nature)) { if (Objects.isNull(nature)||!nature.startsWith("_")) { // 系统的字典是以 _ 开头的, 过滤因引用外部字典导致的异常
return false; return false;
} }
Long id = getId(nature); Long id = getId(nature);

View File

@@ -22,4 +22,6 @@ public class Text2SQLExemplar implements Serializable {
private String dbSchema; private String dbSchema;
private String sql; private String sql;
protected double similarity; // 传递相似度,可以作为样本筛选的依据
} }

View File

@@ -72,7 +72,10 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
embeddingService.retrieveQuery(collection, retrieveQuery, num); embeddingService.retrieveQuery(collection, retrieveQuery, num);
results.forEach(ret -> { results.forEach(ret -> {
ret.getRetrieval().forEach(r -> { ret.getRetrieval().forEach(r -> {
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class)); Text2SQLExemplar tmp = //传递相似度,可以作为样本筛选的依据
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
tmp.setSimilarity(r.getSimilarity());
exemplars.add(tmp);
}); });
}); });

View File

@@ -18,6 +18,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT; import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT;
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT; import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
@@ -65,12 +66,23 @@ public class SemanticParseInfo implements Serializable {
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches()); DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity(); double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
if (difference == 0) { if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势
if (!o1.getDataSetId().equals(o2.getDataSetId())) {
List<SchemaElementMatch> elementMatches1 = o1.getElementMatches().stream()
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
List<SchemaElementMatch> elementMatches2 = o2.getElementMatches().stream()
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
if (elementMatches1.size() > elementMatches2.size()) {
return -1;
} else if (elementMatches1.size() < elementMatches2.size()) {
return 1;
}
}
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity(); difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
if (difference == 0) { if (Math.abs(difference) < 0.0005) {
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity(); difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
} }
if (difference == 0) { if (Math.abs(difference) < 0.0005) {
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt(); difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
} }
} }

View File

@@ -14,10 +14,8 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Collections; import java.util.stream.Collectors;
import java.util.List;
import java.util.Objects;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.*;
@@ -51,13 +49,33 @@ public class PromptHelper {
// use random collection of exemplars for each self-consistency inference // use random collection of exemplars for each self-consistency inference
for (int i = 0; i < selfConsistencyNumber; i++) { for (int i = 0; i < selfConsistencyNumber; i++) {
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars); List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
// only shuffle the exemplars from config List<Text2SQLExemplar> same = shuffledList.stream() // 相似度极高的话,先找出来
List<Text2SQLExemplar> subList = .filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList());
shuffledList.subList(llmReq.getDynamicExemplars().size(), shuffledList.size()); List<Text2SQLExemplar> noSame = shuffledList.stream()
Collections.shuffle(subList); .filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList());
results.add(shuffledList.subList(0, Math.min(shuffledList.size(), fewShotNumber))); if ((noSame.size() - same.size()) > fewShotNumber) {// 去除部分最低分
noSame.sort(Comparator.comparingDouble(Text2SQLExemplar::getSimilarity));
noSame = noSame.subList((noSame.size() - fewShotNumber) / 2, noSame.size());
}
Text2SQLExemplar mostSimilar = noSame.get(noSame.size() - 1);
Collections.shuffle(noSame);
List<Text2SQLExemplar> ts;
if (same.size() > 0) {// 一样的话,必须作为提示语
ts = new ArrayList<>();
int needSize = Math.min(noSame.size() + same.size(), fewShotNumber);
if (needSize > same.size()) {
ts.addAll(noSame.subList(0, needSize - same.size()));
}
ts.addAll(same);
} else { // 至少要一个最像的
ts = noSame.subList(0, Math.min(noSame.size(), fewShotNumber));
if (!ts.contains(mostSimilar)) {
ts.remove(ts.size() - 1);
ts.add(mostSimilar);
}
}
results.add(ts);
} }
return results; return results;
} }