mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(common|headless|chat|auth) 鉴权优化与召回优化
1 修复生成的用户token 一生成就失效的问题 2 如果用户设置的token ,需校验是否数据库存在,因为用户可设置一年的token 有泄露风险 3 结果解析优化, 去除不可以解析的情况,解析问题需要改写后的问, 4 召回样例,用相似度,保住至少有一个样例是高相似度的 5 数据集召回,填加完全匹配格式筛选逻辑
This commit is contained in:
@@ -19,6 +19,7 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|
||||||
import java.sql.Timestamp;
|
import java.sql.Timestamp;
|
||||||
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -223,8 +224,8 @@ public class DefaultUserAdaptor implements UserAdaptor {
|
|||||||
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin());
|
||||||
|
|
||||||
// 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况
|
// 使用令牌名称作为生成key ,这样可以区分正常请求和api 请求,api 的令牌失效时间很长,需考虑令牌泄露的情况
|
||||||
String token =
|
String token = tokenService.generateToken(UserWithPassword.convert(userWithPassword),
|
||||||
tokenService.generateToken(UserWithPassword.convert(userWithPassword),"SysDbToken:"+name, (new Date().getTime() + expireTime));
|
"SysDbToken:" + name, (new Date().getTime() + expireTime));
|
||||||
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
|
UserTokenDO userTokenDO = saveUserToken(name, userName, token, expireTime);
|
||||||
return convertUserToken(userTokenDO);
|
return convertUserToken(userTokenDO);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ public interface UserRepository {
|
|||||||
|
|
||||||
UserTokenDO getUserToken(Long tokenId);
|
UserTokenDO getUserToken(Long tokenId);
|
||||||
|
|
||||||
|
UserTokenDO getUserTokenByName(String tokenName);
|
||||||
|
|
||||||
void deleteUserTokenByName(String userName);
|
void deleteUserTokenByName(String userName);
|
||||||
|
|
||||||
void deleteUserToken(Long tokenId);
|
void deleteUserToken(Long tokenId);
|
||||||
|
|||||||
@@ -65,6 +65,13 @@ public class UserRepositoryImpl implements UserRepository {
|
|||||||
return userTokenDOMapper.selectById(tokenId);
|
return userTokenDOMapper.selectById(tokenId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public UserTokenDO getUserTokenByName(String tokenName) {
|
||||||
|
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||||
|
queryWrapper.lambda().eq(UserTokenDO::getName, tokenName);
|
||||||
|
return userTokenDOMapper.selectOne(queryWrapper);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void deleteUserTokenByName(String userName) {
|
public void deleteUserTokenByName(String userName) {
|
||||||
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
QueryWrapper<UserTokenDO> queryWrapper = new QueryWrapper<>();
|
||||||
|
|||||||
@@ -94,10 +94,11 @@ public class TokenService {
|
|||||||
|
|
||||||
public Optional<Claims> getClaims(String token, String appKey) {
|
public Optional<Claims> getClaims(String token, String appKey) {
|
||||||
try {
|
try {
|
||||||
if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
|
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) {// 如果是配置的长期令牌,需校验数据库是否存在该配置
|
||||||
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
UserRepository userRepository = ContextUtils.getBean(UserRepository.class);
|
||||||
UserTokenDO dbToken= userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
|
UserTokenDO dbToken =
|
||||||
if(dbToken==null||!dbToken.getToken().equals(token.replace("Bearer ",""))) {
|
userRepository.getUserTokenByName(appKey.substring("SysDbToken:".length()));
|
||||||
|
if (dbToken == null || !dbToken.getToken().equals(token.replace("Bearer ", ""))) {
|
||||||
throw new AccessException("Token does not exist :" + appKey);
|
throw new AccessException("Token does not exist :" + appKey);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -133,14 +134,14 @@ public class TokenService {
|
|||||||
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
|
Map<String, String> appKeyToSecretMap = authenticationConfig.getAppKeyToSecretMap();
|
||||||
String secret = appKeyToSecretMap.get(appKey);
|
String secret = appKeyToSecretMap.get(appKey);
|
||||||
if (StringUtils.isBlank(secret)) {
|
if (StringUtils.isBlank(secret)) {
|
||||||
if(StringUtils.isNotBlank(appKey)&&appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
|
if (StringUtils.isNotBlank(appKey) && appKey.startsWith("SysDbToken:")) { // 是配置的长期令牌
|
||||||
String realAppKey=appKey.substring("SysDbToken:".length());
|
String realAppKey = appKey.substring("SysDbToken:".length());
|
||||||
String tmp = "WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
|
String tmp =
|
||||||
if(tmp.length()<=realAppKey.length()) {
|
"WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==";
|
||||||
|
if (tmp.length() <= realAppKey.length()) {
|
||||||
return realAppKey;
|
return realAppKey;
|
||||||
}
|
} else {
|
||||||
else{
|
return realAppKey + tmp.substring(realAppKey.length());
|
||||||
return realAppKey+tmp.substring(realAppKey.length());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
throw new AccessException("get secret from appKey failed :" + appKey);
|
throw new AccessException("get secret from appKey failed :" + appKey);
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
|||||||
Agent agent = executeContext.getAgent();
|
Agent agent = executeContext.getAgent();
|
||||||
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
ChatApp chatApp = agent.getChatAppConfig().get(APP_KEY);
|
||||||
return Objects.nonNull(chatApp) && chatApp.isEnable()
|
return Objects.nonNull(chatApp) && chatApp.isEnable()
|
||||||
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
|
&& StringUtils.isNotBlank(executeContext.getResponse().getTextResult()) // 如果都没结果,则无法处理
|
||||||
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
|
&& StringUtils.isBlank(executeContext.getResponse().getTextSummary()); // 如果已经有汇总的结果了,无法再次处理
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -59,10 +59,11 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
|||||||
|
|
||||||
Map<String, Object> variable = new HashMap<>();
|
Map<String, Object> variable = new HashMap<>();
|
||||||
String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
|
String question = executeContext.getResponse().getTextResult();// 结果解析应该用改写的问题,因为改写的内容信息量更大
|
||||||
if(executeContext.getParseInfo().getProperties()!=null&&
|
if (executeContext.getParseInfo().getProperties() != null
|
||||||
executeContext.getParseInfo().getProperties().containsKey("CONTEXT")){
|
&& executeContext.getParseInfo().getProperties().containsKey("CONTEXT")) {
|
||||||
Map<String,Object> context = (Map<String, Object>) executeContext.getParseInfo().getProperties().get("CONTEXT");
|
Map<String, Object> context = (Map<String, Object>) executeContext.getParseInfo()
|
||||||
if(context.get("queryText")!=null&&"".equals(context.get("queryText"))){
|
.getProperties().get("CONTEXT");
|
||||||
|
if (context.get("queryText") != null && "".equals(context.get("queryText"))) {
|
||||||
question = context.get("queryText").toString();
|
question = context.get("queryText").toString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,7 +21,8 @@ public class LoadRemoveService {
|
|||||||
List<String> resultList = new ArrayList<>(value);
|
List<String> resultList = new ArrayList<>(value);
|
||||||
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
if (!CollectionUtils.isEmpty(modelIdOrDataSetIds)) {
|
||||||
resultList.removeIf(nature -> {
|
resultList.removeIf(nature -> {
|
||||||
if (Objects.isNull(nature)||!nature.startsWith("_")) { // 系统的字典是以 _ 开头的, 过滤因引用外部字典导致的异常
|
if (Objects.isNull(nature) || !nature.startsWith("_")) { // 系统的字典是以 _ 开头的,
|
||||||
|
// 过滤因引用外部字典导致的异常
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
Long id = getId(nature);
|
Long id = getId(nature);
|
||||||
|
|||||||
@@ -23,5 +23,5 @@ public class Text2SQLExemplar implements Serializable {
|
|||||||
|
|
||||||
private String sql;
|
private String sql;
|
||||||
|
|
||||||
protected double similarity; // 传递相似度,可以作为样本筛选的依据
|
protected double similarity; // 传递相似度,可以作为样本筛选的依据
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||||
results.forEach(ret -> {
|
results.forEach(ret -> {
|
||||||
ret.getRetrieval().forEach(r -> {
|
ret.getRetrieval().forEach(r -> {
|
||||||
Text2SQLExemplar tmp = //传递相似度,可以作为样本筛选的依据
|
Text2SQLExemplar tmp = // 传递相似度,可以作为样本筛选的依据
|
||||||
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
|
JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class);
|
||||||
tmp.setSimilarity(r.getSimilarity());
|
tmp.setSimilarity(r.getSimilarity());
|
||||||
exemplars.add(tmp);
|
exemplars.add(tmp);
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ public class SemanticParseInfo implements Serializable {
|
|||||||
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
|
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
|
||||||
|
|
||||||
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
|
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
|
||||||
if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势
|
if (Math.abs(difference) < 0.0005) { // 看完全匹配的个数,实践证明,可以用户输入规范后,该逻辑具有优势
|
||||||
if (!o1.getDataSetId().equals(o2.getDataSetId())) {
|
if (!o1.getDataSetId().equals(o2.getDataSetId())) {
|
||||||
List<SchemaElementMatch> elementMatches1 = o1.getElementMatches().stream()
|
List<SchemaElementMatch> elementMatches1 = o1.getElementMatches().stream()
|
||||||
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
|
.filter(e -> e.getSimilarity() == 1).collect(Collectors.toList());
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ public class PromptHelper {
|
|||||||
// use random collection of exemplars for each self-consistency inference
|
// use random collection of exemplars for each self-consistency inference
|
||||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||||
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||||
List<Text2SQLExemplar> same = shuffledList.stream() // 相似度极高的话,先找出来
|
List<Text2SQLExemplar> same = shuffledList.stream() // 相似度极高的话,先找出来
|
||||||
.filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList());
|
.filter(e -> e.getSimilarity() > 0.989).collect(Collectors.toList());
|
||||||
List<Text2SQLExemplar> noSame = shuffledList.stream()
|
List<Text2SQLExemplar> noSame = shuffledList.stream()
|
||||||
.filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList());
|
.filter(e -> e.getSimilarity() <= 0.989).collect(Collectors.toList());
|
||||||
|
|||||||
Reference in New Issue
Block a user